From e69bfb8eaf758e2f46b837f6f8a45c14d47d3505 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Mon, 31 Jul 2023 14:00:54 +0530 Subject: [PATCH 01/21] Implemented Symbolic ASR pass --- src/libasr/CMakeLists.txt | 1 + src/libasr/gen_pass.py | 1 + src/libasr/pass/pass_manager.h | 3 + src/libasr/pass/replace_symbolic.cpp | 166 +++++++++++++++++++++++++++ src/libasr/pass/replace_symbolic.h | 14 +++ 5 files changed, 185 insertions(+) create mode 100644 src/libasr/pass/replace_symbolic.cpp create mode 100644 src/libasr/pass/replace_symbolic.h diff --git a/src/libasr/CMakeLists.txt b/src/libasr/CMakeLists.txt index fe702eca7d..d5e41e9b0c 100644 --- a/src/libasr/CMakeLists.txt +++ b/src/libasr/CMakeLists.txt @@ -48,6 +48,7 @@ set(SRC pass/unused_functions.cpp pass/flip_sign.cpp pass/div_to_mul.cpp + pass/replace_symbolic.cpp pass/intrinsic_function.cpp pass/fma.cpp pass/loop_vectorise.cpp diff --git a/src/libasr/gen_pass.py b/src/libasr/gen_pass.py index 42776bdf9c..c77e4c29fd 100644 --- a/src/libasr/gen_pass.py +++ b/src/libasr/gen_pass.py @@ -12,6 +12,7 @@ "replace_implied_do_loops", "replace_init_expr", "inline_function_calls", + "replace_symbolic", "replace_intrinsic_function", "loop_unroll", "loop_vectorise", diff --git a/src/libasr/pass/pass_manager.h b/src/libasr/pass/pass_manager.h index 80f21c2c21..7913cb7891 100644 --- a/src/libasr/pass/pass_manager.h +++ b/src/libasr/pass/pass_manager.h @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -71,6 +72,7 @@ namespace LCompilers { {"global_stmts", &pass_wrap_global_stmts}, {"implied_do_loops", &pass_replace_implied_do_loops}, {"array_op", &pass_replace_array_op}, + {"symbolic", &pass_replace_symbolic}, {"intrinsic_function", &pass_replace_intrinsic_function}, {"arr_slice", &pass_replace_arr_slice}, {"print_arr", &pass_replace_print_arr}, @@ -203,6 +205,7 @@ namespace LCompilers { "subroutine_from_function", "where", "array_op", + "symbolic", "intrinsic_function", "array_op", "pass_array_by_data", diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp new file mode 100644 index 0000000000..79c77cbbb3 --- /dev/null +++ b/src/libasr/pass/replace_symbolic.cpp @@ -0,0 +1,166 @@ +#include +#include +#include +#include +#include +#include +#include + +#include + + +namespace LCompilers { + +using ASR::down_cast; +using ASR::is_a; + + +class ReplaceSymbolicVisitor : public PassUtils::PassVisitor +{ +public: + ReplaceSymbolicVisitor(Allocator &al_) : + PassVisitor(al_, nullptr) { + pass_result.reserve(al, 1); + } + + bool symbolic_replaces_with_CPtr_Function = false; + + void visit_Function(const ASR::Function_t &x) { + // FIXME: this is a hack, we need to pass in a non-const `x`, + // which requires to generate a TransformVisitor. + ASR::Function_t &xx = const_cast(x); + SymbolTable* current_scope_copy = this->current_scope; + this->current_scope = xx.m_symtab; + for (auto &item : x.m_symtab->get_scope()) { + if (ASR::is_a(*item.second)) { + ASR::Variable_t *s = ASR::down_cast(item.second); + this->visit_Variable(*s); + } + } + transform_stmts(xx.m_body, xx.n_body); + + // Add "basic_new_stack" to dependencies if needed + if (symbolic_replaces_with_CPtr_Function) { + SetChar function_dependencies; + function_dependencies.n = 0; + function_dependencies.reserve(al, 1); + for( size_t i = 0; i < xx.n_dependencies; i++ ) { + function_dependencies.push_back(al, xx.m_dependencies[i]); + } + function_dependencies.push_back(al, s2c(al, "basic_new_stack")); + xx.n_dependencies = function_dependencies.size(); + xx.m_dependencies = function_dependencies.p; + } + this->current_scope = current_scope_copy; + } + + void visit_Variable(const ASR::Variable_t& x) { + ASR::Variable_t& xx = const_cast(x); + SymbolTable* current_scope_copy = current_scope; + current_scope = xx.m_parent_symtab; + if (xx.m_type->type == ASR::ttypeType::SymbolicExpression) { + SymbolTable* module_scope = current_scope->parent; + symbolic_replaces_with_CPtr_Function = true; + std::string var_name = xx.m_name; + std::string placeholder = "_" + std::string(var_name); + + // defining CPtr variable + ASR::ttype_t *type1 = ASRUtils::TYPE(ASR::make_CPtr_t(al, xx.base.base.loc)); + xx.m_type = type1; + + ASR::ttype_t *type2 = ASRUtils::TYPE(ASR::make_Integer_t(al, xx.base.base.loc, 8)); + ASR::symbol_t* sym2 = ASR::down_cast( + ASR::make_Variable_t(al, xx.base.base.loc, current_scope, + s2c(al, placeholder), nullptr, 0, + xx.m_intent, nullptr, + nullptr, xx.m_storage, + type2, nullptr, xx.m_abi, + xx.m_access, xx.m_presence, + xx.m_value_attr)); + + current_scope->add_symbol(s2c(al, placeholder), sym2); + + std::string new_name = "basic_new_stack"; + if (!module_scope->get_symbol(new_name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable *fn_symtab = al.make_new(module_scope); + + Vec args; + { + args.reserve(al, 1); + ASR::symbol_t *arg = ASR::down_cast(ASR::make_Variable_t( + al, xx.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, type1, nullptr, + ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, arg))); + } + + Vec body; + body.reserve(al, 1); + + Vec dep; + dep.reserve(al, 1); + + ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, xx.base.base.loc, + fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, + nullptr, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t *new_symbol = ASR::down_cast(new_subrout); + module_scope->add_symbol(new_name, new_symbol); + } + + ASR::symbol_t* var_sym = current_scope->get_symbol(var_name); + ASR::symbol_t* placeholder_sym = current_scope->get_symbol(placeholder); + ASR::expr_t* target1 = ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, placeholder_sym)); + ASR::expr_t* target2 = ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, var_sym)); + + // statement 1 + ASR::expr_t* value1 = ASRUtils::EXPR(ASR::make_Cast_t(al, xx.base.base.loc, + ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, xx.base.base.loc, 0, + ASRUtils::TYPE(ASR::make_Integer_t(al, xx.base.base.loc, 4)))), + (ASR::cast_kindType)ASR::cast_kindType::IntegerToInteger, type2, + ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, xx.base.base.loc, 0, type2)))); + + // statement 2 + ASR::expr_t* value2 = ASRUtils::EXPR(ASR::make_PointerNullConstant_t(al, xx.base.base.loc, type1)); + + // statement 3 + ASR::expr_t* get_pointer_node = ASRUtils::EXPR(ASR::make_GetPointer_t(al, xx.base.base.loc, + target1, ASRUtils::TYPE(ASR::make_Pointer_t(al, xx.base.base.loc, type2)), nullptr)); + ASR::expr_t* value3 = ASRUtils::EXPR(ASR::make_PointerToCPtr_t(al, xx.base.base.loc, get_pointer_node, + type1, nullptr)); + + // statement 4 + ASR::symbol_t* basic_new_stack_sym = module_scope->get_symbol(new_name); + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = xx.base.base.loc; + call_arg.m_value = target2; + call_args.push_back(al, call_arg); + + // defining the assignment statement + ASR::stmt_t* stmt1 = ASRUtils::STMT(ASR::make_Assignment_t(al, xx.base.base.loc, target1, value1, nullptr)); + ASR::stmt_t* stmt2 = ASRUtils::STMT(ASR::make_Assignment_t(al, xx.base.base.loc, target2, value2, nullptr)); + ASR::stmt_t* stmt3 = ASRUtils::STMT(ASR::make_Assignment_t(al, xx.base.base.loc, target2, value3, nullptr)); + ASR::stmt_t* stmt4 = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, xx.base.base.loc, basic_new_stack_sym, + basic_new_stack_sym, call_args.p, call_args.n, nullptr)); + + pass_result.push_back(al, stmt1); + pass_result.push_back(al, stmt2); + pass_result.push_back(al, stmt3); + pass_result.push_back(al, stmt4); + } + current_scope = current_scope_copy; + } +}; + +void pass_replace_symbolic(Allocator &al, ASR::TranslationUnit_t &unit, + const LCompilers::PassOptions& /*pass_options*/) { + ReplaceSymbolicVisitor v(al); + v.visit_TranslationUnit(unit); +} + +} // namespace LCompilers \ No newline at end of file diff --git a/src/libasr/pass/replace_symbolic.h b/src/libasr/pass/replace_symbolic.h new file mode 100644 index 0000000000..7e32aefffc --- /dev/null +++ b/src/libasr/pass/replace_symbolic.h @@ -0,0 +1,14 @@ +#ifndef LIBASR_PASS_REPLACE_SYMBOLIC_H +#define LIBASR_PASS_REPLACE_SYMBOLIC_H + +#include +#include + +namespace LCompilers { + + void pass_replace_symbolic(Allocator &al, ASR::TranslationUnit_t &unit, + const PassOptions &pass_options); + +} // namespace LCompilers + +#endif // LIBASR_PASS_REPLACE_SYMBOLIC_H \ No newline at end of file From 2ddfd1c57a72e30cc1867ccf783a2653b9fcd26f Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Tue, 1 Aug 2023 10:26:58 +0530 Subject: [PATCH 02/21] Added code for visit_Assignment --- src/libasr/pass/replace_symbolic.cpp | 96 +++++++++++++++++++++++----- 1 file changed, 80 insertions(+), 16 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 79c77cbbb3..92c25911f6 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include @@ -23,7 +24,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorcurrent_scope = current_scope_copy; } void visit_Variable(const ASR::Variable_t& x) { ASR::Variable_t& xx = const_cast(x); - SymbolTable* current_scope_copy = current_scope; - current_scope = xx.m_parent_symtab; if (xx.m_type->type == ASR::ttypeType::SymbolicExpression) { SymbolTable* module_scope = current_scope->parent; - symbolic_replaces_with_CPtr_Function = true; + symbolic_stack_required = true; std::string var_name = xx.m_name; std::string placeholder = "_" + std::string(var_name); - // defining CPtr variable ASR::ttype_t *type1 = ASRUtils::TYPE(ASR::make_CPtr_t(al, xx.base.base.loc)); xx.m_type = type1; @@ -153,7 +154,70 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*x.m_value)) { + ASR::IntrinsicFunction_t* intrinsic_func = ASR::down_cast(x.m_value); + int64_t intrinsic_id = intrinsic_func->m_intrinsic_id; + SymbolTable* module_scope = current_scope->parent; + if (intrinsic_func->m_type->type == ASR::ttypeType::SymbolicExpression) { + switch (static_cast(intrinsic_id)) { + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPi: { + symbolic_pi_required = true; + std::string new_name = "basic_const_pi"; + if (!module_scope->get_symbol(new_name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); + + Vec args; + args.reserve(al, 1); + ASR::symbol_t* arg = ASR::down_cast(ASR::make_Variable_t( + al, x.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg))); + + Vec body; + body.reserve(al, 1); + + Vec dep; + dep.reserve(al, 1); + + ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, x.base.base.loc, + fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, + nullptr, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); + module_scope->add_symbol(s2c(al, new_name), new_symbol); + } + + // Extract the symbol from the target (Var) + ASR::symbol_t* var_sym = ASR::down_cast(x.m_target)->m_v; + ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + + // Create the function call statement for basic_const_pi + ASR::symbol_t* basic_const_pi_sym = module_scope->get_symbol(new_name); + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = x.base.base.loc; + call_arg.m_value = target; + call_args.push_back(al, call_arg); + + ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, basic_const_pi_sym, + basic_const_pi_sym, call_args.p, call_args.n, nullptr)); + pass_result.push_back(al, stmt); + break; + } + default: { + // TODO + } + } + } + } } }; From e4215d801c9b6508f927134df079274237422d55 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Wed, 2 Aug 2023 08:17:34 +0530 Subject: [PATCH 03/21] Added visit_Print function --- src/libasr/pass/replace_symbolic.cpp | 81 ++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 92c25911f6..88255491a7 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -26,6 +26,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorcurrent_scope = current_scope_copy; @@ -219,6 +223,83 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor print_tmp; + SymbolTable* module_scope = current_scope->parent; + for (size_t i=0; i(*value) && ASR::is_a(*ASRUtils::expr_type(value))) { + symbolic_str_required = true; + std::string new_name = "basic_str"; + if (!module_scope->get_symbol(new_name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); + + Vec args; + args.reserve(al, 1); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, x.base.base.loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); + fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, x.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg2))); + + Vec body; + body.reserve(al, 1); + + Vec dep; + dep.reserve(al, 1); + + ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, fn_symtab->get_symbol("_lpython_return_variable"))); + ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, x.base.base.loc, + fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, + return_var, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); + module_scope->add_symbol(s2c(al, new_name), new_symbol); + } + + // Extract the symbol from value (Var) + ASR::symbol_t* var_sym = ASR::down_cast(value)->m_v; + ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + + // Now create the FunctionCall node for basic_str + ASR::symbol_t* basic_str_sym = module_scope->get_symbol(new_name); + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = x.base.base.loc; + call_arg.m_value = target; + call_args.push_back(al, call_arg); + ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, + basic_str_sym, basic_str_sym, call_args.p, call_args.n, + ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr)); + print_tmp.push_back(function_call); + } else { + print_tmp.push_back(x.m_values[i]); + } + } + if (!print_tmp.empty()) { + Vec tmp_vec; + tmp_vec.reserve(al, print_tmp.size()); + for (auto &e: print_tmp) { + tmp_vec.push_back(al, e); + } + ASR::stmt_t *print_stmt = ASRUtils::STMT( + ASR::make_Print_t(al, x.base.base.loc, nullptr, tmp_vec.p, tmp_vec.size(), + x.m_separator, x.m_end)); + print_tmp.clear(); + pass_result.push_back(al, print_stmt); + } + } }; void pass_replace_symbolic(Allocator &al, ASR::TranslationUnit_t &unit, From 4658cc1a00a582ca35f79381d68c4e75781ec768 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Wed, 2 Aug 2023 09:23:14 +0530 Subject: [PATCH 04/21] Added symbolic_dependencies --- src/libasr/pass/replace_symbolic.cpp | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 88255491a7..f98b8084dd 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -23,10 +23,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor symbolic_dependcies; void visit_Function(const ASR::Function_t &x) { // FIXME: this is a hack, we need to pass in a non-const `x`, @@ -48,14 +45,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(x); if (xx.m_type->type == ASR::ttypeType::SymbolicExpression) { SymbolTable* module_scope = current_scope->parent; - symbolic_stack_required = true; std::string var_name = xx.m_name; std::string placeholder = "_" + std::string(var_name); @@ -86,6 +76,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitoradd_symbol(s2c(al, placeholder), sym2); std::string new_name = "basic_new_stack"; + symbolic_dependcies.push_back(new_name); if (!module_scope->get_symbol(new_name)) { std::string header = "symengine/cwrapper.h"; SymbolTable *fn_symtab = al.make_new(module_scope); @@ -168,8 +159,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_type->type == ASR::ttypeType::SymbolicExpression) { switch (static_cast(intrinsic_id)) { case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPi: { - symbolic_pi_required = true; std::string new_name = "basic_const_pi"; + symbolic_dependcies.push_back(new_name); if (!module_scope->get_symbol(new_name)) { std::string header = "symengine/cwrapper.h"; SymbolTable* fn_symtab = al.make_new(module_scope); @@ -231,8 +222,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*value) && ASR::is_a(*ASRUtils::expr_type(value))) { - symbolic_str_required = true; std::string new_name = "basic_str"; + symbolic_dependcies.push_back(new_name); if (!module_scope->get_symbol(new_name)) { std::string header = "symengine/cwrapper.h"; SymbolTable* fn_symtab = al.make_new(module_scope); From 6f91349f2871e61881e54946aeb0f51a08dcd65e Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Wed, 2 Aug 2023 10:44:22 +0530 Subject: [PATCH 05/21] Added functionality for Symbol Set --- src/libasr/pass/replace_symbolic.cpp | 58 ++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index f98b8084dd..2170899dcb 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -207,6 +207,64 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(new_name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); + + Vec args; + args.reserve(al, 1); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, x.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg1); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg1))); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, x.base.base.loc, fn_symtab, s2c(al, "s"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "s"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg2))); + + Vec body; + body.reserve(al, 1); + + Vec dep; + dep.reserve(al, 1); + + ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, x.base.base.loc, + fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, + nullptr, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); + module_scope->add_symbol(s2c(al, new_name), new_symbol); + } + + // Extract the symbol from the target (Var) + ASR::symbol_t* var_sym = ASR::down_cast(x.m_target)->m_v; + ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + + // Create the function call statement for symbol_set + ASR::symbol_t* symbol_set_sym = module_scope->get_symbol(new_name); + Vec call_args; + call_args.reserve(al, 2); + ASR::call_arg_t call_arg1, call_arg2; + call_arg1.loc = x.base.base.loc; + call_arg1.m_value = target; + call_arg2.loc = x.base.base.loc; + call_arg2.m_value = intrinsic_func->m_args[0]; + call_args.push_back(al, call_arg1); + call_args.push_back(al, call_arg2); + + ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, symbol_set_sym, + symbol_set_sym, call_args.p, call_args.n, nullptr)); + pass_result.push_back(al, stmt); + break; + } default: { // TODO } From e478c4330798e8ac6eed569a105a83f5cbe4feda Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Wed, 2 Aug 2023 11:47:01 +0530 Subject: [PATCH 06/21] Added symbol set functionality --- src/libasr/pass/replace_symbolic.cpp | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 2170899dcb..c55fdd5e18 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -23,7 +23,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor symbolic_dependcies; + std::vector symbolic_dependencies; void visit_Function(const ASR::Function_t &x) { // FIXME: this is a hack, we need to pass in a non-const `x`, @@ -45,9 +45,10 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorcurrent_scope = current_scope_copy; @@ -76,7 +77,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitoradd_symbol(s2c(al, placeholder), sym2); std::string new_name = "basic_new_stack"; - symbolic_dependcies.push_back(new_name); + symbolic_dependencies.push_back(new_name); if (!module_scope->get_symbol(new_name)) { std::string header = "symengine/cwrapper.h"; SymbolTable *fn_symtab = al.make_new(module_scope); @@ -160,7 +161,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(intrinsic_id)) { case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPi: { std::string new_name = "basic_const_pi"; - symbolic_dependcies.push_back(new_name); + symbolic_dependencies.push_back(new_name); if (!module_scope->get_symbol(new_name)) { std::string header = "symengine/cwrapper.h"; SymbolTable* fn_symtab = al.make_new(module_scope); @@ -209,7 +210,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(new_name)) { std::string header = "symengine/cwrapper.h"; SymbolTable* fn_symtab = al.make_new(module_scope); @@ -281,7 +282,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*value) && ASR::is_a(*ASRUtils::expr_type(value))) { std::string new_name = "basic_str"; - symbolic_dependcies.push_back(new_name); + symbolic_dependencies.push_back(new_name); if (!module_scope->get_symbol(new_name)) { std::string header = "symengine/cwrapper.h"; SymbolTable* fn_symtab = al.make_new(module_scope); From 39bc7c4d53b71012fb69c7dbb60b70993c1e6003 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Wed, 2 Aug 2023 11:51:26 +0530 Subject: [PATCH 07/21] Revert "Added symbol set functionality" This reverts commit e478c4330798e8ac6eed569a105a83f5cbe4feda. --- src/libasr/pass/replace_symbolic.cpp | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index c55fdd5e18..2170899dcb 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -23,7 +23,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor symbolic_dependencies; + std::vector symbolic_dependcies; void visit_Function(const ASR::Function_t &x) { // FIXME: this is a hack, we need to pass in a non-const `x`, @@ -45,10 +45,9 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorcurrent_scope = current_scope_copy; @@ -77,7 +76,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitoradd_symbol(s2c(al, placeholder), sym2); std::string new_name = "basic_new_stack"; - symbolic_dependencies.push_back(new_name); + symbolic_dependcies.push_back(new_name); if (!module_scope->get_symbol(new_name)) { std::string header = "symengine/cwrapper.h"; SymbolTable *fn_symtab = al.make_new(module_scope); @@ -161,7 +160,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(intrinsic_id)) { case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPi: { std::string new_name = "basic_const_pi"; - symbolic_dependencies.push_back(new_name); + symbolic_dependcies.push_back(new_name); if (!module_scope->get_symbol(new_name)) { std::string header = "symengine/cwrapper.h"; SymbolTable* fn_symtab = al.make_new(module_scope); @@ -210,7 +209,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(new_name)) { std::string header = "symengine/cwrapper.h"; SymbolTable* fn_symtab = al.make_new(module_scope); @@ -282,7 +281,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*value) && ASR::is_a(*ASRUtils::expr_type(value))) { std::string new_name = "basic_str"; - symbolic_dependencies.push_back(new_name); + symbolic_dependcies.push_back(new_name); if (!module_scope->get_symbol(new_name)) { std::string header = "symengine/cwrapper.h"; SymbolTable* fn_symtab = al.make_new(module_scope); From 13f5e63660114a4162ea4e56e7da6eb90b6b1042 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Wed, 2 Aug 2023 11:55:33 +0530 Subject: [PATCH 08/21] cleared symbolic dependencies vector --- src/libasr/pass/replace_symbolic.cpp | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 2170899dcb..c55fdd5e18 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -23,7 +23,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor symbolic_dependcies; + std::vector symbolic_dependencies; void visit_Function(const ASR::Function_t &x) { // FIXME: this is a hack, we need to pass in a non-const `x`, @@ -45,9 +45,10 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorcurrent_scope = current_scope_copy; @@ -76,7 +77,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitoradd_symbol(s2c(al, placeholder), sym2); std::string new_name = "basic_new_stack"; - symbolic_dependcies.push_back(new_name); + symbolic_dependencies.push_back(new_name); if (!module_scope->get_symbol(new_name)) { std::string header = "symengine/cwrapper.h"; SymbolTable *fn_symtab = al.make_new(module_scope); @@ -160,7 +161,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(intrinsic_id)) { case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPi: { std::string new_name = "basic_const_pi"; - symbolic_dependcies.push_back(new_name); + symbolic_dependencies.push_back(new_name); if (!module_scope->get_symbol(new_name)) { std::string header = "symengine/cwrapper.h"; SymbolTable* fn_symtab = al.make_new(module_scope); @@ -209,7 +210,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(new_name)) { std::string header = "symengine/cwrapper.h"; SymbolTable* fn_symtab = al.make_new(module_scope); @@ -281,7 +282,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*value) && ASR::is_a(*ASRUtils::expr_type(value))) { std::string new_name = "basic_str"; - symbolic_dependcies.push_back(new_name); + symbolic_dependencies.push_back(new_name); if (!module_scope->get_symbol(new_name)) { std::string header = "symengine/cwrapper.h"; SymbolTable* fn_symtab = al.make_new(module_scope); From 3ba0b14690347081f8491377ce7b11771dce6418 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Wed, 2 Aug 2023 13:21:28 +0530 Subject: [PATCH 09/21] Added support for all binary operations --- src/libasr/pass/intrinsic_function_registry.h | 8 -- src/libasr/pass/replace_symbolic.cpp | 105 ++++++++++++++++-- 2 files changed, 95 insertions(+), 18 deletions(-) diff --git a/src/libasr/pass/intrinsic_function_registry.h b/src/libasr/pass/intrinsic_function_registry.h index d8fe1d173f..320dec4ec8 100644 --- a/src/libasr/pass/intrinsic_function_registry.h +++ b/src/libasr/pass/intrinsic_function_registry.h @@ -2971,14 +2971,6 @@ namespace X{ diag::Diagnostics& diagnostics) { \ ASRUtils::require_impl(x.n_args == 2, "Intrinsic function `"#X"` accepts \ exactly 2 arguments", x.base.base.loc, diagnostics); \ - \ - ASR::ttype_t* left_type = ASRUtils::expr_type(x.m_args[0]); \ - ASR::ttype_t* right_type = ASRUtils::expr_type(x.m_args[1]); \ - \ - ASRUtils::require_impl(ASR::is_a(*left_type) && \ - ASR::is_a(*right_type), \ - "Both arguments of `"#X"` must be of type SymbolicExpression", \ - x.base.base.loc, diagnostics); \ } \ \ static inline ASR::expr_t* eval_##X(Allocator &/*al*/, const Location &/*loc*/, \ diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index c55fdd5e18..f9a61b5ece 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -152,6 +152,69 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(new_name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); + + Vec args; + args.reserve(al, 3); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg1); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1))); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "y"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); + ASR::symbol_t* arg3 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "z"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "z"), arg3); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg3))); + + Vec body; + body.reserve(al, 1); + + Vec dep; + dep.reserve(al, 1); + + ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, loc, + fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, + nullptr, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); + module_scope->add_symbol(s2c(al, new_name), new_symbol); + } + + // Create the function call statement for symbol_set + ASR::symbol_t* func_sym = module_scope->get_symbol(new_name); + Vec call_args; + call_args.reserve(al, 3); + ASR::call_arg_t call_arg1, call_arg2, call_arg3; + call_arg1.loc = loc; + call_arg1.m_value = value1; + call_arg2.loc = loc; + call_arg2.m_value = value2; + call_arg3.loc = loc; + call_arg3.m_value = value3; + call_args.push_back(al, call_arg1); + call_args.push_back(al, call_arg2); + call_args.push_back(al, call_arg3); + + ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, func_sym, + func_sym, call_args.p, call_args.n, nullptr)); + pass_result.push_back(al, stmt); + } + void visit_Assignment(const ASR::Assignment_t &x) { if (ASR::is_a(*x.m_value)) { ASR::IntrinsicFunction_t* intrinsic_func = ASR::down_cast(x.m_value); @@ -190,17 +253,13 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitoradd_symbol(s2c(al, new_name), new_symbol); } - // Extract the symbol from the target (Var) - ASR::symbol_t* var_sym = ASR::down_cast(x.m_target)->m_v; - ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - // Create the function call statement for basic_const_pi ASR::symbol_t* basic_const_pi_sym = module_scope->get_symbol(new_name); Vec call_args; call_args.reserve(al, 1); ASR::call_arg_t call_arg; call_arg.loc = x.base.base.loc; - call_arg.m_value = target; + call_arg.m_value = x.m_target; call_args.push_back(al, call_arg); ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, basic_const_pi_sym, @@ -245,17 +304,13 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitoradd_symbol(s2c(al, new_name), new_symbol); } - // Extract the symbol from the target (Var) - ASR::symbol_t* var_sym = ASR::down_cast(x.m_target)->m_v; - ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - // Create the function call statement for symbol_set ASR::symbol_t* symbol_set_sym = module_scope->get_symbol(new_name); Vec call_args; call_args.reserve(al, 2); ASR::call_arg_t call_arg1, call_arg2; call_arg1.loc = x.base.base.loc; - call_arg1.m_value = target; + call_arg1.m_value = x.m_target; call_arg2.loc = x.base.base.loc; call_arg2.m_value = intrinsic_func->m_args[0]; call_args.push_back(al, call_arg1); @@ -266,6 +321,36 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_args[0], intrinsic_func->m_args[1]); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSub: { + perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_sub", + x.m_target, intrinsic_func->m_args[0], intrinsic_func->m_args[1]); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicMul: { + perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_mul", + x.m_target, intrinsic_func->m_args[0], intrinsic_func->m_args[1]); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiv: { + perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_div", + x.m_target, intrinsic_func->m_args[0], intrinsic_func->m_args[1]); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPow: { + perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_pow", + x.m_target, intrinsic_func->m_args[0], intrinsic_func->m_args[1]); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiff: { + perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_diff", + x.m_target, intrinsic_func->m_args[0], intrinsic_func->m_args[1]); + break; + } default: { // TODO } From 45fc8114ced6236e6d8f5fffdfe7b4ec845cc5de Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Wed, 2 Aug 2023 13:48:04 +0530 Subject: [PATCH 10/21] Added support for all unary operations/functions --- src/libasr/pass/intrinsic_function_registry.h | 4 - src/libasr/pass/replace_symbolic.cpp | 84 ++++++++++++++++++- 2 files changed, 83 insertions(+), 5 deletions(-) diff --git a/src/libasr/pass/intrinsic_function_registry.h b/src/libasr/pass/intrinsic_function_registry.h index 320dec4ec8..1cf0113695 100644 --- a/src/libasr/pass/intrinsic_function_registry.h +++ b/src/libasr/pass/intrinsic_function_registry.h @@ -3075,10 +3075,6 @@ namespace X { const Location& loc = x.base.base.loc; \ ASRUtils::require_impl(x.n_args == 1, \ #X " must have exactly 1 input argument", loc, diagnostics); \ - \ - ASR::ttype_t* input_type = ASRUtils::expr_type(x.m_args[0]); \ - ASRUtils::require_impl(ASR::is_a(*input_type), \ - #X " expects an argument of type SymbolicExpression", loc, diagnostics); \ } \ \ static inline ASR::expr_t* eval_##X(Allocator &/*al*/, const Location &/*loc*/, \ diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index f9a61b5ece..5be291823a 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -195,7 +195,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitoradd_symbol(s2c(al, new_name), new_symbol); } - // Create the function call statement for symbol_set ASR::symbol_t* func_sym = module_scope->get_symbol(new_name); Vec call_args; call_args.reserve(al, 3); @@ -215,6 +214,59 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(new_name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); + + Vec args; + args.reserve(al, 2); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg1); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1))); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "y"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); + + Vec body; + body.reserve(al, 1); + + Vec dep; + dep.reserve(al, 1); + + ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, loc, + fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, + nullptr, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); + module_scope->add_symbol(s2c(al, new_name), new_symbol); + } + + ASR::symbol_t* func_sym = module_scope->get_symbol(new_name); + Vec call_args; + call_args.reserve(al, 2); + ASR::call_arg_t call_arg1, call_arg2; + call_arg1.loc = loc; + call_arg1.m_value = value1; + call_arg2.loc = loc; + call_arg2.m_value = value2; + call_args.push_back(al, call_arg1); + call_args.push_back(al, call_arg2); + + ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, func_sym, + func_sym, call_args.p, call_args.n, nullptr)); + pass_result.push_back(al, stmt); + } + void visit_Assignment(const ASR::Assignment_t &x) { if (ASR::is_a(*x.m_value)) { ASR::IntrinsicFunction_t* intrinsic_func = ASR::down_cast(x.m_value); @@ -351,6 +403,36 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_args[0], intrinsic_func->m_args[1]); break; } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSin: { + perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_sin", + x.m_target, intrinsic_func->m_args[0]); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicCos: { + perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_cos", + x.m_target, intrinsic_func->m_args[0]); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicLog: { + perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_log", + x.m_target, intrinsic_func->m_args[0]); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExp: { + perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_exp", + x.m_target, intrinsic_func->m_args[0]); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicAbs: { + perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_abs", + x.m_target, intrinsic_func->m_args[0]); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExpand: { + perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_expand", + x.m_target, intrinsic_func->m_args[0]); + break; + } default: { // TODO } From d2369b5d4c6b791b0e3eba313cab2b01edfc350b Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Thu, 3 Aug 2023 10:38:40 +0530 Subject: [PATCH 11/21] Added symbolic_vars set --- integration_tests/CMakeLists.txt | 12 ++++++------ src/libasr/pass/replace_symbolic.cpp | 9 +++++++-- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index 56a07b2c62..abee630062 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -675,12 +675,12 @@ RUN(NAME structs_32 LABELS cpython llvm c) RUN(NAME structs_33 LABELS cpython llvm c) RUN(NAME structs_34 LABELS cpython llvm c) -RUN(NAME symbolics_01 LABELS cpython_sym c_sym) -RUN(NAME symbolics_02 LABELS cpython_sym c_sym) -RUN(NAME symbolics_03 LABELS cpython_sym c_sym) -RUN(NAME symbolics_04 LABELS cpython_sym c_sym) -RUN(NAME symbolics_05 LABELS cpython_sym c_sym) -RUN(NAME symbolics_06 LABELS cpython_sym c_sym) +# RUN(NAME symbolics_01 LABELS cpython_sym c_sym) +# RUN(NAME symbolics_02 LABELS cpython_sym c_sym) +# RUN(NAME symbolics_03 LABELS cpython_sym c_sym) +# RUN(NAME symbolics_04 LABELS cpython_sym c_sym) +# RUN(NAME symbolics_05 LABELS cpython_sym c_sym) +# RUN(NAME symbolics_06 LABELS cpython_sym c_sym) RUN(NAME symbolics_07 LABELS cpython_sym c_sym llvm_sym NOFAST) RUN(NAME symbolics_08 LABELS cpython_sym c_sym llvm_sym) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 5be291823a..05131205b2 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -24,6 +24,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor symbolic_dependencies; + std::set symbolic_vars; void visit_Function(const ASR::Function_t &x) { // FIXME: this is a hack, we need to pass in a non-const `x`, @@ -63,6 +64,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor((ASR::asr_t*)&xx)); ASR::ttype_t *type2 = ASRUtils::TYPE(ASR::make_Integer_t(al, xx.base.base.loc, 8)); ASR::symbol_t* sym2 = ASR::down_cast( @@ -434,7 +436,9 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorparent; for (size_t i=0; i(*value) && ASR::is_a(*ASRUtils::expr_type(value))) { + ASR::symbol_t *v = ASR::down_cast(value)->m_v; + if (symbolic_vars.find(v) == symbolic_vars.end()) return; std::string new_name = "basic_str"; symbolic_dependencies.push_back(new_name); if (!module_scope->get_symbol(new_name)) { From 75bc2d7a23a6080aedbe3bf2bd2fe47b406d0db2 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Thu, 3 Aug 2023 12:51:42 +0530 Subject: [PATCH 12/21] Added symbolic integer functionality --- src/libasr/pass/replace_symbolic.cpp | 72 +++++++++++++++++++++++++++- 1 file changed, 70 insertions(+), 2 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 05131205b2..43b664e690 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -270,10 +270,10 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorparent; if (ASR::is_a(*x.m_value)) { ASR::IntrinsicFunction_t* intrinsic_func = ASR::down_cast(x.m_value); int64_t intrinsic_id = intrinsic_func->m_intrinsic_id; - SymbolTable* module_scope = current_scope->parent; if (intrinsic_func->m_type->type == ASR::ttypeType::SymbolicExpression) { switch (static_cast(intrinsic_id)) { case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPi: { @@ -358,7 +358,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitoradd_symbol(s2c(al, new_name), new_symbol); } - // Create the function call statement for symbol_set ASR::symbol_t* symbol_set_sym = module_scope->get_symbol(new_name); Vec call_args; call_args.reserve(al, 2); @@ -442,6 +441,75 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*x.m_value)) { + ASR::Cast_t* cast_t = ASR::down_cast(x.m_value); + if (cast_t->m_kind == ASR::cast_kindType::IntegerToSymbolicExpression) { + ASR::expr_t* cast_arg = cast_t->m_arg; + ASR::expr_t* cast_value = cast_t->m_value; + if (ASR::is_a(*cast_value)) { + ASR::IntrinsicFunction_t* intrinsic_func = ASR::down_cast(cast_value); + int64_t intrinsic_id = intrinsic_func->m_intrinsic_id; + if (static_cast(intrinsic_id) == + LCompilers::ASRUtils::IntrinsicFunctions::SymbolicInteger) { + ASR::IntegerConstant_t* const_int = ASR::down_cast(cast_arg); + int const_value = const_int->m_n; + std::string new_name = "integer_set_si"; + symbolic_dependencies.push_back(new_name); + if (!module_scope->get_symbol(new_name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); + + Vec args; + args.reserve(al, 2); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, x.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg1); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg1))); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, x.base.base.loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, 8)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "y"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg2))); + + Vec body; + body.reserve(al, 1); + + Vec dep; + dep.reserve(al, 1); + + ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, x.base.base.loc, + fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, + nullptr, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); + module_scope->add_symbol(s2c(al, new_name), new_symbol); + } + + ASR::symbol_t* integer_set_sym = module_scope->get_symbol(new_name); + ASR::ttype_t* cast_type = ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, 8)); + ASR::expr_t* value = ASRUtils::EXPR(ASR::make_Cast_t(al, x.base.base.loc, cast_arg, + (ASR::cast_kindType)ASR::cast_kindType::IntegerToInteger, cast_type, + ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, x.base.base.loc, const_value, cast_type)))); + Vec call_args; + call_args.reserve(al, 2); + ASR::call_arg_t call_arg1, call_arg2; + call_arg1.loc = x.base.base.loc; + call_arg1.m_value = x.m_target; + call_arg2.loc = x.base.base.loc; + call_arg2.m_value = value; + call_args.push_back(al, call_arg1); + call_args.push_back(al, call_arg2); + + ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, integer_set_sym, + integer_set_sym, call_args.p, call_args.n, nullptr)); + pass_result.push_back(al, stmt); + } + } + } } } From a0cda6f04ee8658b68373adf34d9c7ef2262ba63 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Fri, 4 Aug 2023 09:49:49 +0530 Subject: [PATCH 13/21] Added asr verify checks --- src/libasr/pass/intrinsic_function_registry.h | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/libasr/pass/intrinsic_function_registry.h b/src/libasr/pass/intrinsic_function_registry.h index 1cf0113695..d8fe1d173f 100644 --- a/src/libasr/pass/intrinsic_function_registry.h +++ b/src/libasr/pass/intrinsic_function_registry.h @@ -2971,6 +2971,14 @@ namespace X{ diag::Diagnostics& diagnostics) { \ ASRUtils::require_impl(x.n_args == 2, "Intrinsic function `"#X"` accepts \ exactly 2 arguments", x.base.base.loc, diagnostics); \ + \ + ASR::ttype_t* left_type = ASRUtils::expr_type(x.m_args[0]); \ + ASR::ttype_t* right_type = ASRUtils::expr_type(x.m_args[1]); \ + \ + ASRUtils::require_impl(ASR::is_a(*left_type) && \ + ASR::is_a(*right_type), \ + "Both arguments of `"#X"` must be of type SymbolicExpression", \ + x.base.base.loc, diagnostics); \ } \ \ static inline ASR::expr_t* eval_##X(Allocator &/*al*/, const Location &/*loc*/, \ @@ -3075,6 +3083,10 @@ namespace X { const Location& loc = x.base.base.loc; \ ASRUtils::require_impl(x.n_args == 1, \ #X " must have exactly 1 input argument", loc, diagnostics); \ + \ + ASR::ttype_t* input_type = ASRUtils::expr_type(x.m_args[0]); \ + ASRUtils::require_impl(ASR::is_a(*input_type), \ + #X " expects an argument of type SymbolicExpression", loc, diagnostics); \ } \ \ static inline ASR::expr_t* eval_##X(Allocator &/*al*/, const Location &/*loc*/, \ From cabefaefa7f578d69f2eeea3ed12caca3dd1d2a7 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Sat, 5 Aug 2023 11:02:16 +0530 Subject: [PATCH 14/21] Extending Functionality of the Symbolic ASR pass --- src/libasr/pass/replace_symbolic.cpp | 163 +++++++++++++++++++++++++++ 1 file changed, 163 insertions(+) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 43b664e690..55510a5f1e 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -15,6 +15,25 @@ namespace LCompilers { using ASR::down_cast; using ASR::is_a; +class SymEngine_Queue { +public: + std::vector queue; + int queue_front = -1; + + SymEngine_Queue() {} + + std::string push() { + std::string var; + var = "queue" + std::to_string(queue.size()); + queue.push_back(var); + queue_front++; + return queue[queue_front]; + } + + void pop() { + queue_front++; + } +}; class ReplaceSymbolicVisitor : public PassUtils::PassVisitor { @@ -25,6 +44,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor symbolic_dependencies; std::set symbolic_vars; + SymEngine_Queue symengine_queue; void visit_Function(const ASR::Function_t &x) { // FIXME: this is a hack, we need to pass in a non-const `x`, @@ -204,8 +224,16 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*value2)){ + ASR::IntrinsicFunction_t *s = ASR::down_cast(value2); + this->visit_IntrinsicFunction(*s); + } call_arg2.m_value = value2; call_arg3.loc = loc; + if (ASR::is_a(*value3)){ + ASR::IntrinsicFunction_t *s = ASR::down_cast(value3); + this->visit_IntrinsicFunction(*s); + } call_arg3.m_value = value3; call_args.push_back(al, call_arg1); call_args.push_back(al, call_arg2); @@ -590,6 +618,141 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitortype == ASR::ttypeType::SymbolicExpression) return; + SymbolTable* module_scope = current_scope->parent; + for (size_t i=0; i(*value)){ + ASR::IntrinsicFunction_t *s = ASR::down_cast(value); + this->visit_IntrinsicFunction(*s); + } + } + + ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, x.base.base.loc)); + std::string target = symengine_queue.push(); + ASR::symbol_t *arg = ASR::down_cast(ASR::make_Variable_t( + al, x.base.base.loc, current_scope, s2c(al, target), nullptr, 0, ASR::intentType::Local, + nullptr, nullptr, ASR::storage_typeType::Default, type, nullptr, + ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); + current_scope->add_symbol(s2c(al, target), arg); + for (auto &item : current_scope->get_scope()) { + if (ASR::is_a(*item.second)) { + ASR::Variable_t *s = ASR::down_cast(item.second); + this->visit_Variable(*s); + } + } + + int64_t intrinsic_id = x.m_intrinsic_id; + switch (static_cast(intrinsic_id)) { + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPi: { + std::string new_name = "basic_const_pi"; + symbolic_dependencies.push_back(new_name); + if (!module_scope->get_symbol(new_name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); + + Vec args; + args.reserve(al, 1); + ASR::symbol_t* arg = ASR::down_cast(ASR::make_Variable_t( + al, x.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg))); + + Vec body; + body.reserve(al, 1); + + Vec dep; + dep.reserve(al, 1); + + ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, x.base.base.loc, + fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, + nullptr, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); + module_scope->add_symbol(s2c(al, new_name), new_symbol); + } + + // Create the function call statement for basic_const_pi + ASR::symbol_t* basic_const_pi_sym = module_scope->get_symbol(new_name); + ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg)); + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = x.base.base.loc; + call_arg.m_value = target; + call_args.push_back(al, call_arg); + + ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, basic_const_pi_sym, + basic_const_pi_sym, call_args.p, call_args.n, nullptr)); + pass_result.push_back(al, stmt); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSymbol: { + std::string new_name = "symbol_set"; + symbolic_dependencies.push_back(new_name); + if (!module_scope->get_symbol(new_name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); + + Vec args; + args.reserve(al, 1); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, x.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg1); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg1))); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, x.base.base.loc, fn_symtab, s2c(al, "s"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "s"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg2))); + + Vec body; + body.reserve(al, 1); + + Vec dep; + dep.reserve(al, 1); + + ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, x.base.base.loc, + fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, + nullptr, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); + module_scope->add_symbol(s2c(al, new_name), new_symbol); + } + + ASR::symbol_t* symbol_set_sym = module_scope->get_symbol(new_name); + ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg)); + Vec call_args; + call_args.reserve(al, 2); + ASR::call_arg_t call_arg1, call_arg2; + call_arg1.loc = x.base.base.loc; + call_arg1.m_value = target; + call_arg2.loc = x.base.base.loc; + call_arg2.m_value = x.m_args[0]; + call_args.push_back(al, call_arg1); + call_args.push_back(al, call_arg2); + + ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, symbol_set_sym, + symbol_set_sym, call_args.p, call_args.n, nullptr)); + pass_result.push_back(al, stmt); + break; + } + default: { + throw LCompilersException("IntrinsicFunction: `" + + ASRUtils::get_intrinsic_name(intrinsic_id) + + "` is not implemented"); + } + } + } }; void pass_replace_symbolic(Allocator &al, ASR::TranslationUnit_t &unit, From edcc052e70e2d7135050e34d356291d5d96f6009 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Sat, 5 Aug 2023 11:29:28 +0530 Subject: [PATCH 15/21] minor change --- src/libasr/pass/replace_symbolic.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 55510a5f1e..d25ce6460a 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -620,7 +620,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitortype == ASR::ttypeType::SymbolicExpression) return; + if(x.m_type->type != ASR::ttypeType::SymbolicExpression) return; SymbolTable* module_scope = current_scope->parent; for (size_t i=0; i Date: Mon, 7 Aug 2023 13:15:49 +0530 Subject: [PATCH 16/21] added support through symengine_stack for all handling operator chaining and other things --- src/libasr/pass/replace_symbolic.cpp | 1152 ++++++++++++++++++++++++-- 1 file changed, 1103 insertions(+), 49 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index d25ce6460a..1c15ba06ab 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -15,23 +15,29 @@ namespace LCompilers { using ASR::down_cast; using ASR::is_a; -class SymEngine_Queue { +class SymEngine_Stack { public: - std::vector queue; - int queue_front = -1; + std::vector stack; + int stack_top = -1; + int count = 0; - SymEngine_Queue() {} + SymEngine_Stack() {} std::string push() { std::string var; - var = "queue" + std::to_string(queue.size()); - queue.push_back(var); - queue_front++; - return queue[queue_front]; + var = "stack" + std::to_string(count); + stack.push_back(var); + stack_top++; + count++; + return stack[stack_top]; } - void pop() { - queue_front++; + std::string pop() { + std::string top = stack[stack_top]; + stack_top--; + stack.pop_back(); + if (stack_top == -1) stack.clear(); + return top; } }; @@ -44,7 +50,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor symbolic_dependencies; std::set symbolic_vars; - SymEngine_Queue symengine_queue; + SymEngine_Stack symengine_stack; void visit_Function(const ASR::Function_t &x) { // FIXME: this is a hack, we need to pass in a non-const `x`, @@ -224,16 +230,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*value2)){ - ASR::IntrinsicFunction_t *s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - } call_arg2.m_value = value2; call_arg3.loc = loc; - if (ASR::is_a(*value3)){ - ASR::IntrinsicFunction_t *s = ASR::down_cast(value3); - this->visit_IntrinsicFunction(*s); - } call_arg3.m_value = value3; call_args.push_back(al, call_arg1); call_args.push_back(al, call_arg2); @@ -403,63 +401,291 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_args[0]; + ASR::expr_t* value2 = intrinsic_func->m_args[1]; + if (ASR::is_a(*value1)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } else if (ASR::is_a(*value1)) { + ASR::Cast_t* s = ASR::down_cast(value1); + this->visit_Cast(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } + + if (ASR::is_a(*value2)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } else if (ASR::is_a(*value2)) { + ASR::Cast_t* s = ASR::down_cast(value2); + this->visit_Cast(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_add", - x.m_target, intrinsic_func->m_args[0], intrinsic_func->m_args[1]); + x.m_target, value1, value2); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSub: { + ASR::expr_t* value1 = intrinsic_func->m_args[0]; + ASR::expr_t* value2 = intrinsic_func->m_args[1]; + if (ASR::is_a(*value1)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } else if (ASR::is_a(*value1)) { + ASR::Cast_t* s = ASR::down_cast(value1); + this->visit_Cast(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } + + if (ASR::is_a(*value2)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } else if (ASR::is_a(*value2)) { + ASR::Cast_t* s = ASR::down_cast(value2); + this->visit_Cast(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_sub", - x.m_target, intrinsic_func->m_args[0], intrinsic_func->m_args[1]); + x.m_target, value1, value2); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicMul: { + ASR::expr_t* value1 = intrinsic_func->m_args[0]; + ASR::expr_t* value2 = intrinsic_func->m_args[1]; + if (ASR::is_a(*value1)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } else if (ASR::is_a(*value1)) { + ASR::Cast_t* s = ASR::down_cast(value1); + this->visit_Cast(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } + + if (ASR::is_a(*value2)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } else if (ASR::is_a(*value2)) { + ASR::Cast_t* s = ASR::down_cast(value2); + this->visit_Cast(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_mul", - x.m_target, intrinsic_func->m_args[0], intrinsic_func->m_args[1]); + x.m_target, value1, value2); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiv: { + ASR::expr_t* value1 = intrinsic_func->m_args[0]; + ASR::expr_t* value2 = intrinsic_func->m_args[1]; + if (ASR::is_a(*value1)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } else if (ASR::is_a(*value1)) { + ASR::Cast_t* s = ASR::down_cast(value1); + this->visit_Cast(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } + + if (ASR::is_a(*value2)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } else if (ASR::is_a(*value2)) { + ASR::Cast_t* s = ASR::down_cast(value2); + this->visit_Cast(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_div", - x.m_target, intrinsic_func->m_args[0], intrinsic_func->m_args[1]); + x.m_target, value1, value2); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPow: { + ASR::expr_t* value1 = intrinsic_func->m_args[0]; + ASR::expr_t* value2 = intrinsic_func->m_args[1]; + if (ASR::is_a(*value1)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } else if (ASR::is_a(*value1)) { + ASR::Cast_t* s = ASR::down_cast(value1); + this->visit_Cast(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } + + if (ASR::is_a(*value2)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } else if (ASR::is_a(*value2)) { + ASR::Cast_t* s = ASR::down_cast(value2); + this->visit_Cast(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_pow", - x.m_target, intrinsic_func->m_args[0], intrinsic_func->m_args[1]); + x.m_target, value1, value2); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiff: { + ASR::expr_t* value1 = intrinsic_func->m_args[0]; + ASR::expr_t* value2 = intrinsic_func->m_args[1]; + if (ASR::is_a(*value1)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } else if (ASR::is_a(*value1)) { + ASR::Cast_t* s = ASR::down_cast(value1); + this->visit_Cast(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } + + if (ASR::is_a(*value2)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } else if (ASR::is_a(*value2)) { + ASR::Cast_t* s = ASR::down_cast(value2); + this->visit_Cast(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_diff", - x.m_target, intrinsic_func->m_args[0], intrinsic_func->m_args[1]); + x.m_target, value1, value2); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSin: { + ASR::expr_t* value = intrinsic_func->m_args[0]; + + if (ASR::is_a(*value)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } else if (ASR::is_a(*value)) { + ASR::Cast_t* s = ASR::down_cast(value); + this->visit_Cast(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_sin", - x.m_target, intrinsic_func->m_args[0]); + x.m_target, value); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicCos: { + ASR::expr_t* value = intrinsic_func->m_args[0]; + + if (ASR::is_a(*value)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } else if (ASR::is_a(*value)) { + ASR::Cast_t* s = ASR::down_cast(value); + this->visit_Cast(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_cos", - x.m_target, intrinsic_func->m_args[0]); + x.m_target, value); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicLog: { + ASR::expr_t* value = intrinsic_func->m_args[0]; + + if (ASR::is_a(*value)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } else if (ASR::is_a(*value)) { + ASR::Cast_t* s = ASR::down_cast(value); + this->visit_Cast(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_log", - x.m_target, intrinsic_func->m_args[0]); + x.m_target, value); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExp: { + ASR::expr_t* value = intrinsic_func->m_args[0]; + + if (ASR::is_a(*value)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } else if (ASR::is_a(*value)) { + ASR::Cast_t* s = ASR::down_cast(value); + this->visit_Cast(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_exp", - x.m_target, intrinsic_func->m_args[0]); + x.m_target, value); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicAbs: { + ASR::expr_t* value = intrinsic_func->m_args[0]; + + if (ASR::is_a(*value)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } else if (ASR::is_a(*value)) { + ASR::Cast_t* s = ASR::down_cast(value); + this->visit_Cast(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_abs", - x.m_target, intrinsic_func->m_args[0]); + x.m_target, value); break; } case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExpand: { + ASR::expr_t* value = intrinsic_func->m_args[0]; + + if (ASR::is_a(*value)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } else if (ASR::is_a(*value)) { + ASR::Cast_t* s = ASR::down_cast(value); + this->visit_Cast(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_expand", - x.m_target, intrinsic_func->m_args[0]); + x.m_target, value); break; } default: { @@ -479,8 +705,16 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_intrinsic_id; if (static_cast(intrinsic_id) == LCompilers::ASRUtils::IntrinsicFunctions::SymbolicInteger) { - ASR::IntegerConstant_t* const_int = ASR::down_cast(cast_arg); - int const_value = const_int->m_n; + int const_value = 0; + if (ASR::is_a(*cast_arg)){ + ASR::IntegerConstant_t* const_int = ASR::down_cast(cast_arg); + const_value = const_int->m_n; + } + if (ASR::is_a(*cast_arg)){ + ASR::IntegerUnaryMinus_t *const_int_minus = ASR::down_cast(cast_arg); + ASR::IntegerConstant_t* const_int = ASR::down_cast(const_int_minus->m_value); + const_value = const_int->m_n; + } std::string new_name = "integer_set_si"; symbolic_dependencies.push_back(new_name); if (!module_scope->get_symbol(new_name)) { @@ -545,9 +779,9 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor print_tmp; SymbolTable* module_scope = current_scope->parent; for (size_t i=0; i(*value) && ASR::is_a(*ASRUtils::expr_type(value))) { - ASR::symbol_t *v = ASR::down_cast(value)->m_v; + ASR::expr_t* val = x.m_values[i]; + if (ASR::is_a(*val) && ASR::is_a(*ASRUtils::expr_type(val))) { + ASR::symbol_t *v = ASR::down_cast(val)->m_v; if (symbolic_vars.find(v) == symbolic_vars.end()) return; std::string new_name = "basic_str"; symbolic_dependencies.push_back(new_name); @@ -586,9 +820,461 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(value)->m_v; + ASR::symbol_t* var_sym = ASR::down_cast(val)->m_v; ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + // Now create the FunctionCall node for basic_str + ASR::symbol_t* basic_str_sym = module_scope->get_symbol(new_name); + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = x.base.base.loc; + call_arg.m_value = target; + call_args.push_back(al, call_arg); + ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, + basic_str_sym, basic_str_sym, call_args.p, call_args.n, + ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr)); + print_tmp.push_back(function_call); + } else if (ASR::is_a(*val) && ASR::is_a(*ASRUtils::expr_type(val))) { + ASR::IntrinsicFunction_t* intrinsic_func = ASR::down_cast(val); + ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, x.base.base.loc)); + std::string symengine_var = symengine_stack.push(); + ASR::symbol_t *arg = ASR::down_cast(ASR::make_Variable_t( + al, x.base.base.loc, current_scope, s2c(al, symengine_var), nullptr, 0, ASR::intentType::Local, + nullptr, nullptr, ASR::storage_typeType::Default, type, nullptr, + ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); + current_scope->add_symbol(s2c(al, symengine_var), arg); + for (auto &item : current_scope->get_scope()) { + if (ASR::is_a(*item.second)) { + ASR::Variable_t *s = ASR::down_cast(item.second); + this->visit_Variable(*s); + } + } + + ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg)); + int64_t intrinsic_id = intrinsic_func->m_intrinsic_id; + switch (static_cast(intrinsic_id)) { + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPi: { + std::string new_name = "basic_const_pi"; + symbolic_dependencies.push_back(new_name); + if (!module_scope->get_symbol(new_name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); + + Vec args; + args.reserve(al, 1); + ASR::symbol_t* arg = ASR::down_cast(ASR::make_Variable_t( + al, x.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg))); + + Vec body; + body.reserve(al, 1); + + Vec dep; + dep.reserve(al, 1); + + ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, x.base.base.loc, + fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, + nullptr, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); + module_scope->add_symbol(s2c(al, new_name), new_symbol); + } + + // Create the function call statement for basic_const_pi + ASR::symbol_t* basic_const_pi_sym = module_scope->get_symbol(new_name); + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = x.base.base.loc; + call_arg.m_value = target; + call_args.push_back(al, call_arg); + + ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, basic_const_pi_sym, + basic_const_pi_sym, call_args.p, call_args.n, nullptr)); + pass_result.push_back(al, stmt); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSymbol: { + std::string new_name = "symbol_set"; + symbolic_dependencies.push_back(new_name); + if (!module_scope->get_symbol(new_name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); + + Vec args; + args.reserve(al, 1); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, x.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg1); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg1))); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, x.base.base.loc, fn_symtab, s2c(al, "s"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "s"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg2))); + + Vec body; + body.reserve(al, 1); + + Vec dep; + dep.reserve(al, 1); + + ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, x.base.base.loc, + fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, + nullptr, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); + module_scope->add_symbol(s2c(al, new_name), new_symbol); + } + + ASR::symbol_t* symbol_set_sym = module_scope->get_symbol(new_name); + Vec call_args; + call_args.reserve(al, 2); + ASR::call_arg_t call_arg1, call_arg2; + call_arg1.loc = x.base.base.loc; + call_arg1.m_value = target; + call_arg2.loc = x.base.base.loc; + call_arg2.m_value = intrinsic_func->m_args[0]; + call_args.push_back(al, call_arg1); + call_args.push_back(al, call_arg2); + + ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, symbol_set_sym, + symbol_set_sym, call_args.p, call_args.n, nullptr)); + pass_result.push_back(al, stmt); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicAdd: { + ASR::expr_t* value1 = intrinsic_func->m_args[0]; + ASR::expr_t* value2 = intrinsic_func->m_args[1]; + if (ASR::is_a(*value1)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } else if (ASR::is_a(*value1)) { + ASR::Cast_t* s = ASR::down_cast(value1); + this->visit_Cast(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } + + if (ASR::is_a(*value2)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } else if (ASR::is_a(*value2)) { + ASR::Cast_t* s = ASR::down_cast(value2); + this->visit_Cast(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } + perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_add", + target, value1, value2); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSub: { + ASR::expr_t* value1 = intrinsic_func->m_args[0]; + ASR::expr_t* value2 = intrinsic_func->m_args[1]; + if (ASR::is_a(*value1)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } else if (ASR::is_a(*value1)) { + ASR::Cast_t* s = ASR::down_cast(value1); + this->visit_Cast(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } + + if (ASR::is_a(*value2)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } else if (ASR::is_a(*value2)) { + ASR::Cast_t* s = ASR::down_cast(value2); + this->visit_Cast(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } + perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_sub", + target, value1, value2); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicMul: { + ASR::expr_t* value1 = intrinsic_func->m_args[0]; + ASR::expr_t* value2 = intrinsic_func->m_args[1]; + if (ASR::is_a(*value1)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } else if (ASR::is_a(*value1)) { + ASR::Cast_t* s = ASR::down_cast(value1); + this->visit_Cast(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } + + if (ASR::is_a(*value2)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } else if (ASR::is_a(*value2)) { + ASR::Cast_t* s = ASR::down_cast(value2); + this->visit_Cast(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } + perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_mul", + target, value1, value2); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiv: { + ASR::expr_t* value1 = intrinsic_func->m_args[0]; + ASR::expr_t* value2 = intrinsic_func->m_args[1]; + if (ASR::is_a(*value1)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } else if (ASR::is_a(*value1)) { + ASR::Cast_t* s = ASR::down_cast(value1); + this->visit_Cast(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } + + if (ASR::is_a(*value2)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } else if (ASR::is_a(*value2)) { + ASR::Cast_t* s = ASR::down_cast(value2); + this->visit_Cast(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } + perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_div", + target, value1, value2); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPow: { + ASR::expr_t* value1 = intrinsic_func->m_args[0]; + ASR::expr_t* value2 = intrinsic_func->m_args[1]; + if (ASR::is_a(*value1)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } else if (ASR::is_a(*value1)) { + ASR::Cast_t* s = ASR::down_cast(value1); + this->visit_Cast(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } + + if (ASR::is_a(*value2)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } else if (ASR::is_a(*value2)) { + ASR::Cast_t* s = ASR::down_cast(value2); + this->visit_Cast(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } + perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_pow", + target, value1, value2); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiff: { + ASR::expr_t* value1 = intrinsic_func->m_args[0]; + ASR::expr_t* value2 = intrinsic_func->m_args[1]; + if (ASR::is_a(*value1)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value1); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } else if (ASR::is_a(*value1)) { + ASR::Cast_t* s = ASR::down_cast(value1); + this->visit_Cast(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } + + if (ASR::is_a(*value2)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value2); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } else if (ASR::is_a(*value2)) { + ASR::Cast_t* s = ASR::down_cast(value2); + this->visit_Cast(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } + perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_diff", + target, value1, value2); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSin: { + ASR::expr_t* value = intrinsic_func->m_args[0]; + if (ASR::is_a(*value)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } else if (ASR::is_a(*value)) { + ASR::Cast_t* s = ASR::down_cast(value); + this->visit_Cast(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } + perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_sin", + target, value); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicCos: { + ASR::expr_t* value = intrinsic_func->m_args[0]; + if (ASR::is_a(*value)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } else if (ASR::is_a(*value)) { + ASR::Cast_t* s = ASR::down_cast(value); + this->visit_Cast(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } + perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_cos", + target, value); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicLog: { + ASR::expr_t* value = intrinsic_func->m_args[0]; + if (ASR::is_a(*value)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } else if (ASR::is_a(*value)) { + ASR::Cast_t* s = ASR::down_cast(value); + this->visit_Cast(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } + perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_log", + target, value); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExp: { + ASR::expr_t* value = intrinsic_func->m_args[0]; + if (ASR::is_a(*value)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } else if (ASR::is_a(*value)) { + ASR::Cast_t* s = ASR::down_cast(value); + this->visit_Cast(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } + perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_exp", + target, value); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicAbs: { + ASR::expr_t* value = intrinsic_func->m_args[0]; + if (ASR::is_a(*value)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } else if (ASR::is_a(*value)) { + ASR::Cast_t* s = ASR::down_cast(value); + this->visit_Cast(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } + perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_abs", + target, value); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExpand: { + ASR::expr_t* value = intrinsic_func->m_args[0]; + if (ASR::is_a(*value)) { + ASR::IntrinsicFunction_t* s = ASR::down_cast(value); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } else if (ASR::is_a(*value)) { + ASR::Cast_t* s = ASR::down_cast(value); + this->visit_Cast(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } + perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_expand", + target, value); + break; + } + default: { + throw LCompilersException("IntrinsicFunction: `" + + ASRUtils::get_intrinsic_name(intrinsic_id) + + "` is not implemented"); + } + } + std::string new_name = "basic_str"; + symbolic_dependencies.push_back(new_name); + if (!module_scope->get_symbol(new_name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); + + Vec args; + args.reserve(al, 1); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, x.base.base.loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); + fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, x.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg2))); + + Vec body; + body.reserve(al, 1); + + Vec dep; + dep.reserve(al, 1); + + ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, fn_symtab->get_symbol("_lpython_return_variable"))); + ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, x.base.base.loc, + fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, + return_var, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); + module_scope->add_symbol(s2c(al, new_name), new_symbol); + } // Now create the FunctionCall node for basic_str ASR::symbol_t* basic_str_sym = module_scope->get_symbol(new_name); Vec call_args; @@ -622,21 +1308,14 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitortype != ASR::ttypeType::SymbolicExpression) return; SymbolTable* module_scope = current_scope->parent; - for (size_t i=0; i(*value)){ - ASR::IntrinsicFunction_t *s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - } - } ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, x.base.base.loc)); - std::string target = symengine_queue.push(); + std::string symengine_var = symengine_stack.push(); ASR::symbol_t *arg = ASR::down_cast(ASR::make_Variable_t( - al, x.base.base.loc, current_scope, s2c(al, target), nullptr, 0, ASR::intentType::Local, + al, x.base.base.loc, current_scope, s2c(al, symengine_var), nullptr, 0, ASR::intentType::Local, nullptr, nullptr, ASR::storage_typeType::Default, type, nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); - current_scope->add_symbol(s2c(al, target), arg); + current_scope->add_symbol(s2c(al, symengine_var), arg); for (auto &item : current_scope->get_scope()) { if (ASR::is_a(*item.second)) { ASR::Variable_t *s = ASR::down_cast(item.second); @@ -645,6 +1324,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(intrinsic_id)) { case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPi: { std::string new_name = "basic_const_pi"; @@ -679,7 +1359,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(new_name); - ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg)); Vec call_args; call_args.reserve(al, 1); ASR::call_arg_t call_arg; @@ -730,7 +1409,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(new_name); - ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg)); Vec call_args; call_args.reserve(al, 2); ASR::call_arg_t call_arg1, call_arg2; @@ -746,6 +1424,288 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*value1)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value1); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } else if (ASR::is_a(*value1)) { + ASR::Cast_t* s = ASR::down_cast(value1); + this->visit_Cast(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } + + if (ASR::is_a(*value2)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value2); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } else if (ASR::is_a(*value2)) { + ASR::Cast_t* s = ASR::down_cast(value2); + this->visit_Cast(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } + perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_add", + target, value1, value2); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSub: { + ASR::expr_t* value1 = x.m_args[0]; + ASR::expr_t* value2 = x.m_args[1]; + if (ASR::is_a(*value1)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value1); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } else if (ASR::is_a(*value1)) { + ASR::Cast_t* s = ASR::down_cast(value1); + this->visit_Cast(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } + + if (ASR::is_a(*value2)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value2); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } else if (ASR::is_a(*value2)) { + ASR::Cast_t* s = ASR::down_cast(value2); + this->visit_Cast(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } + perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_sub", + target, value1, value2); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicMul: { + ASR::expr_t* value1 = x.m_args[0]; + ASR::expr_t* value2 = x.m_args[1]; + if (ASR::is_a(*value1)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value1); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } else if (ASR::is_a(*value1)) { + ASR::Cast_t* s = ASR::down_cast(value1); + this->visit_Cast(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } + + if (ASR::is_a(*value2)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value2); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } else if (ASR::is_a(*value2)) { + ASR::Cast_t* s = ASR::down_cast(value2); + this->visit_Cast(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } + perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_mul", + target, value1, value2); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiv: { + ASR::expr_t* value1 = x.m_args[0]; + ASR::expr_t* value2 = x.m_args[1]; + if (ASR::is_a(*value1)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value1); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } else if (ASR::is_a(*value1)) { + ASR::Cast_t* s = ASR::down_cast(value1); + this->visit_Cast(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } + + if (ASR::is_a(*value2)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value2); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } else if (ASR::is_a(*value2)) { + ASR::Cast_t* s = ASR::down_cast(value2); + this->visit_Cast(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } + perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_div", + target, value1, value2); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPow: { + ASR::expr_t* value1 = x.m_args[0]; + ASR::expr_t* value2 = x.m_args[1]; + if (ASR::is_a(*value1)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value1); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } else if (ASR::is_a(*value1)) { + ASR::Cast_t* s = ASR::down_cast(value1); + this->visit_Cast(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } + + if (ASR::is_a(*value2)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value2); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } else if (ASR::is_a(*value2)) { + ASR::Cast_t* s = ASR::down_cast(value2); + this->visit_Cast(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } + perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_pow", + target, value1, value2); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiff: { + ASR::expr_t* value1 = x.m_args[0]; + ASR::expr_t* value2 = x.m_args[1]; + if (ASR::is_a(*value1)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value1); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } else if (ASR::is_a(*value1)) { + ASR::Cast_t* s = ASR::down_cast(value1); + this->visit_Cast(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } + + if (ASR::is_a(*value2)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value2); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } else if (ASR::is_a(*value2)) { + ASR::Cast_t* s = ASR::down_cast(value2); + this->visit_Cast(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } + perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_diff", + target, value1, value2); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSin: { + ASR::expr_t* value = x.m_args[0]; + if (ASR::is_a(*value)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } else if (ASR::is_a(*value)) { + ASR::Cast_t* s = ASR::down_cast(value); + this->visit_Cast(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } + perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_sin", + target, value); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicCos: { + ASR::expr_t* value = x.m_args[0]; + if (ASR::is_a(*value)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } else if (ASR::is_a(*value)) { + ASR::Cast_t* s = ASR::down_cast(value); + this->visit_Cast(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } + perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_cos", + target, value); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicLog: { + ASR::expr_t* value = x.m_args[0]; + if (ASR::is_a(*value)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } else if (ASR::is_a(*value)) { + ASR::Cast_t* s = ASR::down_cast(value); + this->visit_Cast(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } + perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_log", + target, value); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExp: { + ASR::expr_t* value = x.m_args[0]; + if (ASR::is_a(*value)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } else if (ASR::is_a(*value)) { + ASR::Cast_t* s = ASR::down_cast(value); + this->visit_Cast(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } + perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_exp", + target, value); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicAbs: { + ASR::expr_t* value = x.m_args[0]; + if (ASR::is_a(*value)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } else if (ASR::is_a(*value)) { + ASR::Cast_t* s = ASR::down_cast(value); + this->visit_Cast(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } + perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_abs", + target, value); + break; + } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExpand: { + ASR::expr_t* value = x.m_args[0]; + if (ASR::is_a(*value)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } else if (ASR::is_a(*value)) { + ASR::Cast_t* s = ASR::down_cast(value); + this->visit_Cast(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } + perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_expand", + target, value); + break; + } default: { throw LCompilersException("IntrinsicFunction: `" + ASRUtils::get_intrinsic_name(intrinsic_id) @@ -753,6 +1713,100 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorparent; + + ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, x.base.base.loc)); + std::string symengine_var = symengine_stack.push(); + ASR::symbol_t *arg = ASR::down_cast(ASR::make_Variable_t( + al, x.base.base.loc, current_scope, s2c(al, symengine_var), nullptr, 0, ASR::intentType::Local, + nullptr, nullptr, ASR::storage_typeType::Default, type, nullptr, + ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); + current_scope->add_symbol(s2c(al, symengine_var), arg); + for (auto &item : current_scope->get_scope()) { + if (ASR::is_a(*item.second)) { + ASR::Variable_t *s = ASR::down_cast(item.second); + this->visit_Variable(*s); + } + } + + ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg)); + ASR::expr_t* cast_arg = x.m_arg; + ASR::expr_t* cast_value = x.m_value; + if (ASR::is_a(*cast_value)) { + ASR::IntrinsicFunction_t* intrinsic_func = ASR::down_cast(cast_value); + int64_t intrinsic_id = intrinsic_func->m_intrinsic_id; + if (static_cast(intrinsic_id) == + LCompilers::ASRUtils::IntrinsicFunctions::SymbolicInteger) { + int const_value = 0; + if (ASR::is_a(*cast_arg)){ + ASR::IntegerConstant_t* const_int = ASR::down_cast(cast_arg); + const_value = const_int->m_n; + } + if (ASR::is_a(*cast_arg)){ + ASR::IntegerUnaryMinus_t *const_int_minus = ASR::down_cast(cast_arg); + ASR::IntegerConstant_t* const_int = ASR::down_cast(const_int_minus->m_value); + const_value = const_int->m_n; + } + std::string new_name = "integer_set_si"; + symbolic_dependencies.push_back(new_name); + if (!module_scope->get_symbol(new_name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); + + Vec args; + args.reserve(al, 2); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, x.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg1); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg1))); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, x.base.base.loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, 8)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "y"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg2))); + + Vec body; + body.reserve(al, 1); + + Vec dep; + dep.reserve(al, 1); + + ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, x.base.base.loc, + fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, + nullptr, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); + module_scope->add_symbol(s2c(al, new_name), new_symbol); + } + + ASR::symbol_t* integer_set_sym = module_scope->get_symbol(new_name); + ASR::ttype_t* cast_type = ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, 8)); + ASR::expr_t* value = ASRUtils::EXPR(ASR::make_Cast_t(al, x.base.base.loc, cast_arg, + (ASR::cast_kindType)ASR::cast_kindType::IntegerToInteger, cast_type, + ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, x.base.base.loc, const_value, cast_type)))); + Vec call_args; + call_args.reserve(al, 2); + ASR::call_arg_t call_arg1, call_arg2; + call_arg1.loc = x.base.base.loc; + call_arg1.m_value = target; + call_arg2.loc = x.base.base.loc; + call_arg2.m_value = value; + call_args.push_back(al, call_arg1); + call_args.push_back(al, call_arg2); + + ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, integer_set_sym, + integer_set_sym, call_args.p, call_args.n, nullptr)); + pass_result.push_back(al, stmt); + } + } + } }; void pass_replace_symbolic(Allocator &al, ASR::TranslationUnit_t &unit, From 3539ccfc8eecf7f40ccf5d34dd2d8b849837b809 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Mon, 7 Aug 2023 15:53:03 +0530 Subject: [PATCH 17/21] Added support for assert statements --- src/libasr/pass/replace_symbolic.cpp | 120 +++++++++++++++++++++++++++ 1 file changed, 120 insertions(+) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 1c15ba06ab..b5c9fe25c5 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -1807,6 +1807,126 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*x.m_test)) return; + ASR::SymbolicCompare_t *s = ASR::down_cast(x.m_test); + SymbolTable* module_scope = current_scope->parent; + ASR::expr_t* left_tmp; + ASR::expr_t* right_tmp; + + std::string new_name = "basic_str"; + symbolic_dependencies.push_back(new_name); + if (!module_scope->get_symbol(new_name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); + + Vec args; + args.reserve(al, 1); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, x.base.base.loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); + fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, x.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg2))); + + Vec body; + body.reserve(al, 1); + + Vec dep; + dep.reserve(al, 1); + + ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, fn_symtab->get_symbol("_lpython_return_variable"))); + ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, x.base.base.loc, + fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, + return_var, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); + module_scope->add_symbol(s2c(al, new_name), new_symbol); + } + ASR::symbol_t* basic_str_sym = module_scope->get_symbol(new_name); + + if(ASR::is_a(*s->m_left)) { + ASR::symbol_t *var_sym = ASR::down_cast(s->m_left)->m_v; + ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + + // Now create the FunctionCall node for basic_str + Vec call_args1; + call_args1.reserve(al, 1); + ASR::call_arg_t call_arg1; + call_arg1.loc = x.base.base.loc; + call_arg1.m_value = target; + call_args1.push_back(al, call_arg1); + ASR::expr_t* function_call1 = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, + basic_str_sym, basic_str_sym, call_args1.p, call_args1.n, + ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr)); + left_tmp = function_call1; + } else if(ASR::is_a(*s->m_left)) { + ASR::IntrinsicFunction_t* intrinsic_func = ASR::down_cast(s->m_left); + this->visit_IntrinsicFunction(*intrinsic_func); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + ASR::expr_t* left_var = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + + // Now create the FunctionCall node for basic_str + Vec call_args1; + call_args1.reserve(al, 1); + ASR::call_arg_t call_arg1; + call_arg1.loc = x.base.base.loc; + call_arg1.m_value = left_var; + call_args1.push_back(al, call_arg1); + ASR::expr_t* function_call1 = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, + basic_str_sym, basic_str_sym, call_args1.p, call_args1.n, + ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr)); + left_tmp = function_call1; + } + + if(ASR::is_a(*s->m_right)) { + ASR::IntrinsicFunction_t* intrinsic_func = ASR::down_cast(s->m_right); + this->visit_IntrinsicFunction(*intrinsic_func); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + ASR::expr_t* right_var = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + + // Now create the FunctionCall node for basic_str + Vec call_args2; + call_args2.reserve(al, 1); + ASR::call_arg_t call_arg2; + call_arg2.loc = x.base.base.loc; + call_arg2.m_value = right_var; + call_args2.push_back(al, call_arg2); + ASR::expr_t* function_call2 = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, + basic_str_sym, basic_str_sym, call_args2.p, call_args2.n, + ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr)); + right_tmp = function_call2; + } else if (ASR::is_a(*s->m_right)) { + ASR::Cast_t* cast_t = ASR::down_cast(s->m_right); + this->visit_Cast(*cast_t); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + ASR::expr_t* right_var = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + + // Now create the FunctionCall node for basic_str + Vec call_args2; + call_args2.reserve(al, 1); + ASR::call_arg_t call_arg2; + call_arg2.loc = x.base.base.loc; + call_arg2.m_value = right_var; + call_args2.push_back(al, call_arg2); + ASR::expr_t* function_call2 = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, + basic_str_sym, basic_str_sym, call_args2.p, call_args2.n, + ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr)); + right_tmp = function_call2; + } + ASR::expr_t* test = ASRUtils::EXPR(ASR::make_StringCompare_t(al, x.base.base.loc, left_tmp, + s->m_op, right_tmp, s->m_type, s->m_value)); + + ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg)); + pass_result.push_back(al, assert_stmt); + } }; void pass_replace_symbolic(Allocator &al, ASR::TranslationUnit_t &unit, From 1f174d81d05da33df65dcd9c3bb8da9df4f46f60 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Mon, 7 Aug 2023 16:04:50 +0530 Subject: [PATCH 18/21] Uncommented all symbolic tests --- integration_tests/CMakeLists.txt | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index abee630062..7bbbcd05e5 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -675,12 +675,12 @@ RUN(NAME structs_32 LABELS cpython llvm c) RUN(NAME structs_33 LABELS cpython llvm c) RUN(NAME structs_34 LABELS cpython llvm c) -# RUN(NAME symbolics_01 LABELS cpython_sym c_sym) -# RUN(NAME symbolics_02 LABELS cpython_sym c_sym) -# RUN(NAME symbolics_03 LABELS cpython_sym c_sym) -# RUN(NAME symbolics_04 LABELS cpython_sym c_sym) -# RUN(NAME symbolics_05 LABELS cpython_sym c_sym) -# RUN(NAME symbolics_06 LABELS cpython_sym c_sym) +RUN(NAME symbolics_01 LABELS cpython_sym c_sym llvm_sym) +RUN(NAME symbolics_02 LABELS cpython_sym c_sym llvm_sym) +RUN(NAME symbolics_03 LABELS cpython_sym c_sym llvm_sym) +RUN(NAME symbolics_04 LABELS cpython_sym c_sym llvm_sym) +RUN(NAME symbolics_05 LABELS cpython_sym c_sym llvm_sym) +RUN(NAME symbolics_06 LABELS cpython_sym c_sym llvm_sym) RUN(NAME symbolics_07 LABELS cpython_sym c_sym llvm_sym NOFAST) RUN(NAME symbolics_08 LABELS cpython_sym c_sym llvm_sym) From 2273ed99f48b5e8ee421b326225cab3cc12ff9bc Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Tue, 8 Aug 2023 11:22:28 +0530 Subject: [PATCH 19/21] Added the NOFAST label for symbolic tests --- integration_tests/CMakeLists.txt | 12 ++++----- src/libasr/pass/replace_symbolic.cpp | 39 +++++++++++++++++++++++++--- 2 files changed, 41 insertions(+), 10 deletions(-) diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index 7bbbcd05e5..0a3b942523 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -675,12 +675,12 @@ RUN(NAME structs_32 LABELS cpython llvm c) RUN(NAME structs_33 LABELS cpython llvm c) RUN(NAME structs_34 LABELS cpython llvm c) -RUN(NAME symbolics_01 LABELS cpython_sym c_sym llvm_sym) -RUN(NAME symbolics_02 LABELS cpython_sym c_sym llvm_sym) -RUN(NAME symbolics_03 LABELS cpython_sym c_sym llvm_sym) -RUN(NAME symbolics_04 LABELS cpython_sym c_sym llvm_sym) -RUN(NAME symbolics_05 LABELS cpython_sym c_sym llvm_sym) -RUN(NAME symbolics_06 LABELS cpython_sym c_sym llvm_sym) +RUN(NAME symbolics_01 LABELS cpython_sym c_sym llvm_sym NOFAST) +RUN(NAME symbolics_02 LABELS cpython_sym c_sym llvm_sym NOFAST) +RUN(NAME symbolics_03 LABELS cpython_sym c_sym llvm_sym NOFAST) +RUN(NAME symbolics_04 LABELS cpython_sym c_sym llvm_sym NOFAST) +RUN(NAME symbolics_05 LABELS cpython_sym c_sym llvm_sym NOFAST) +RUN(NAME symbolics_06 LABELS cpython_sym c_sym llvm_sym NOFAST) RUN(NAME symbolics_07 LABELS cpython_sym c_sym llvm_sym NOFAST) RUN(NAME symbolics_08 LABELS cpython_sym c_sym llvm_sym) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index b5c9fe25c5..828939d197 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -36,7 +36,6 @@ class SymEngine_Stack { std::string top = stack[stack_top]; stack_top--; stack.pop_back(); - if (stack_top == -1) stack.clear(); return top; } }; @@ -1853,8 +1852,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(new_name); if(ASR::is_a(*s->m_left)) { - ASR::symbol_t *var_sym = ASR::down_cast(s->m_left)->m_v; - ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + ASR::symbol_t *var_sym1 = ASR::down_cast(s->m_left)->m_v; + ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); // Now create the FunctionCall node for basic_str Vec call_args1; @@ -1873,6 +1872,23 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(symengine_stack.pop()); ASR::expr_t* left_var = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + // Now create the FunctionCall node for basic_str + Vec call_args1; + call_args1.reserve(al, 1); + ASR::call_arg_t call_arg1; + call_arg1.loc = x.base.base.loc; + call_arg1.m_value = left_var; + call_args1.push_back(al, call_arg1); + ASR::expr_t* function_call1 = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, + basic_str_sym, basic_str_sym, call_args1.p, call_args1.n, + ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr)); + left_tmp = function_call1; + } else if (ASR::is_a(*s->m_left)) { + ASR::Cast_t* cast_t = ASR::down_cast(s->m_left); + this->visit_Cast(*cast_t); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + ASR::expr_t* left_var = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + // Now create the FunctionCall node for basic_str Vec call_args1; call_args1.reserve(al, 1); @@ -1886,7 +1902,22 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*s->m_right)) { + if(ASR::is_a(*s->m_right)) { + ASR::symbol_t *var_sym1 = ASR::down_cast(s->m_right)->m_v; + ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + + // Now create the FunctionCall node for basic_str + Vec call_args2; + call_args2.reserve(al, 1); + ASR::call_arg_t call_arg2; + call_arg2.loc = x.base.base.loc; + call_arg2.m_value = target; + call_args2.push_back(al, call_arg2); + ASR::expr_t* function_call2 = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, + basic_str_sym, basic_str_sym, call_args2.p, call_args2.n, + ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr)); + right_tmp = function_call2; + } else if(ASR::is_a(*s->m_right)) { ASR::IntrinsicFunction_t* intrinsic_func = ASR::down_cast(s->m_right); this->visit_IntrinsicFunction(*intrinsic_func); ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); From 0ad03a64f69be5dcc4670b04c86214bbdd572844 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Tue, 8 Aug 2023 11:42:53 +0530 Subject: [PATCH 20/21] Minor changes to correct failing tests --- src/libasr/pass/replace_symbolic.cpp | 757 ++++++++++++++------------- 1 file changed, 379 insertions(+), 378 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 828939d197..2759e0e857 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -1305,410 +1305,411 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitortype != ASR::ttypeType::SymbolicExpression) return; - SymbolTable* module_scope = current_scope->parent; - - ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, x.base.base.loc)); - std::string symengine_var = symengine_stack.push(); - ASR::symbol_t *arg = ASR::down_cast(ASR::make_Variable_t( - al, x.base.base.loc, current_scope, s2c(al, symengine_var), nullptr, 0, ASR::intentType::Local, - nullptr, nullptr, ASR::storage_typeType::Default, type, nullptr, - ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); - current_scope->add_symbol(s2c(al, symengine_var), arg); - for (auto &item : current_scope->get_scope()) { - if (ASR::is_a(*item.second)) { - ASR::Variable_t *s = ASR::down_cast(item.second); - this->visit_Variable(*s); - } - } - - int64_t intrinsic_id = x.m_intrinsic_id; - ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg)); - switch (static_cast(intrinsic_id)) { - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPi: { - std::string new_name = "basic_const_pi"; - symbolic_dependencies.push_back(new_name); - if (!module_scope->get_symbol(new_name)) { - std::string header = "symengine/cwrapper.h"; - SymbolTable* fn_symtab = al.make_new(module_scope); - - Vec args; - args.reserve(al, 1); - ASR::symbol_t* arg = ASR::down_cast(ASR::make_Variable_t( - al, x.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "x"), arg); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg))); - - Vec body; - body.reserve(al, 1); - - Vec dep; - dep.reserve(al, 1); + if(x.m_type && x.m_type->type == ASR::ttypeType::SymbolicExpression) { + SymbolTable* module_scope = current_scope->parent; - ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, x.base.base.loc, - fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, - nullptr, ASR::abiType::BindC, ASR::accessType::Public, - ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, - false, false, nullptr, 0, false, false, false, s2c(al, header)); - ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); - module_scope->add_symbol(s2c(al, new_name), new_symbol); + ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, x.base.base.loc)); + std::string symengine_var = symengine_stack.push(); + ASR::symbol_t *arg = ASR::down_cast(ASR::make_Variable_t( + al, x.base.base.loc, current_scope, s2c(al, symengine_var), nullptr, 0, ASR::intentType::Local, + nullptr, nullptr, ASR::storage_typeType::Default, type, nullptr, + ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); + current_scope->add_symbol(s2c(al, symengine_var), arg); + for (auto &item : current_scope->get_scope()) { + if (ASR::is_a(*item.second)) { + ASR::Variable_t *s = ASR::down_cast(item.second); + this->visit_Variable(*s); } - - // Create the function call statement for basic_const_pi - ASR::symbol_t* basic_const_pi_sym = module_scope->get_symbol(new_name); - Vec call_args; - call_args.reserve(al, 1); - ASR::call_arg_t call_arg; - call_arg.loc = x.base.base.loc; - call_arg.m_value = target; - call_args.push_back(al, call_arg); - - ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, basic_const_pi_sym, - basic_const_pi_sym, call_args.p, call_args.n, nullptr)); - pass_result.push_back(al, stmt); - break; } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSymbol: { - std::string new_name = "symbol_set"; - symbolic_dependencies.push_back(new_name); - if (!module_scope->get_symbol(new_name)) { - std::string header = "symengine/cwrapper.h"; - SymbolTable* fn_symtab = al.make_new(module_scope); - Vec args; - args.reserve(al, 1); - ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( - al, x.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "x"), arg1); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg1))); - ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( - al, x.base.base.loc, fn_symtab, s2c(al, "s"), nullptr, 0, ASR::intentType::In, - nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), - nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); - fn_symtab->add_symbol(s2c(al, "s"), arg2); - args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg2))); - - Vec body; - body.reserve(al, 1); - - Vec dep; - dep.reserve(al, 1); + int64_t intrinsic_id = x.m_intrinsic_id; + ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg)); + switch (static_cast(intrinsic_id)) { + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPi: { + std::string new_name = "basic_const_pi"; + symbolic_dependencies.push_back(new_name); + if (!module_scope->get_symbol(new_name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); + + Vec args; + args.reserve(al, 1); + ASR::symbol_t* arg = ASR::down_cast(ASR::make_Variable_t( + al, x.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg))); + + Vec body; + body.reserve(al, 1); + + Vec dep; + dep.reserve(al, 1); + + ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, x.base.base.loc, + fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, + nullptr, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); + module_scope->add_symbol(s2c(al, new_name), new_symbol); + } - ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, x.base.base.loc, - fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, - nullptr, ASR::abiType::BindC, ASR::accessType::Public, - ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, - false, false, nullptr, 0, false, false, false, s2c(al, header)); - ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); - module_scope->add_symbol(s2c(al, new_name), new_symbol); + // Create the function call statement for basic_const_pi + ASR::symbol_t* basic_const_pi_sym = module_scope->get_symbol(new_name); + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = x.base.base.loc; + call_arg.m_value = target; + call_args.push_back(al, call_arg); + + ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, basic_const_pi_sym, + basic_const_pi_sym, call_args.p, call_args.n, nullptr)); + pass_result.push_back(al, stmt); + break; } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSymbol: { + std::string new_name = "symbol_set"; + symbolic_dependencies.push_back(new_name); + if (!module_scope->get_symbol(new_name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); + + Vec args; + args.reserve(al, 1); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, x.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg1); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg1))); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, x.base.base.loc, fn_symtab, s2c(al, "s"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "s"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg2))); + + Vec body; + body.reserve(al, 1); + + Vec dep; + dep.reserve(al, 1); + + ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, x.base.base.loc, + fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, + nullptr, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* new_symbol = ASR::down_cast(new_subrout); + module_scope->add_symbol(s2c(al, new_name), new_symbol); + } - ASR::symbol_t* symbol_set_sym = module_scope->get_symbol(new_name); - Vec call_args; - call_args.reserve(al, 2); - ASR::call_arg_t call_arg1, call_arg2; - call_arg1.loc = x.base.base.loc; - call_arg1.m_value = target; - call_arg2.loc = x.base.base.loc; - call_arg2.m_value = x.m_args[0]; - call_args.push_back(al, call_arg1); - call_args.push_back(al, call_arg2); - - ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, symbol_set_sym, - symbol_set_sym, call_args.p, call_args.n, nullptr)); - pass_result.push_back(al, stmt); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicAdd: { - ASR::expr_t* value1 = x.m_args[0]; - ASR::expr_t* value2 = x.m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + ASR::symbol_t* symbol_set_sym = module_scope->get_symbol(new_name); + Vec call_args; + call_args.reserve(al, 2); + ASR::call_arg_t call_arg1, call_arg2; + call_arg1.loc = x.base.base.loc; + call_arg1.m_value = target; + call_arg2.loc = x.base.base.loc; + call_arg2.m_value = x.m_args[0]; + call_args.push_back(al, call_arg1); + call_args.push_back(al, call_arg2); + + ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, symbol_set_sym, + symbol_set_sym, call_args.p, call_args.n, nullptr)); + pass_result.push_back(al, stmt); + break; } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicAdd: { + ASR::expr_t* value1 = x.m_args[0]; + ASR::expr_t* value2 = x.m_args[1]; + if (ASR::is_a(*value1)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value1); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } else if (ASR::is_a(*value1)) { + ASR::Cast_t* s = ASR::down_cast(value1); + this->visit_Cast(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_add", - target, value1, value2); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSub: { - ASR::expr_t* value1 = x.m_args[0]; - ASR::expr_t* value2 = x.m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + if (ASR::is_a(*value2)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value2); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } else if (ASR::is_a(*value2)) { + ASR::Cast_t* s = ASR::down_cast(value2); + this->visit_Cast(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } + perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_add", + target, value1, value2); + break; } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSub: { + ASR::expr_t* value1 = x.m_args[0]; + ASR::expr_t* value2 = x.m_args[1]; + if (ASR::is_a(*value1)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value1); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } else if (ASR::is_a(*value1)) { + ASR::Cast_t* s = ASR::down_cast(value1); + this->visit_Cast(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_sub", - target, value1, value2); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicMul: { - ASR::expr_t* value1 = x.m_args[0]; - ASR::expr_t* value2 = x.m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + if (ASR::is_a(*value2)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value2); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } else if (ASR::is_a(*value2)) { + ASR::Cast_t* s = ASR::down_cast(value2); + this->visit_Cast(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } + perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_sub", + target, value1, value2); + break; } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicMul: { + ASR::expr_t* value1 = x.m_args[0]; + ASR::expr_t* value2 = x.m_args[1]; + if (ASR::is_a(*value1)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value1); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } else if (ASR::is_a(*value1)) { + ASR::Cast_t* s = ASR::down_cast(value1); + this->visit_Cast(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_mul", - target, value1, value2); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiv: { - ASR::expr_t* value1 = x.m_args[0]; - ASR::expr_t* value2 = x.m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + if (ASR::is_a(*value2)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value2); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } else if (ASR::is_a(*value2)) { + ASR::Cast_t* s = ASR::down_cast(value2); + this->visit_Cast(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } + perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_mul", + target, value1, value2); + break; } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiv: { + ASR::expr_t* value1 = x.m_args[0]; + ASR::expr_t* value2 = x.m_args[1]; + if (ASR::is_a(*value1)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value1); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } else if (ASR::is_a(*value1)) { + ASR::Cast_t* s = ASR::down_cast(value1); + this->visit_Cast(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_div", - target, value1, value2); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPow: { - ASR::expr_t* value1 = x.m_args[0]; - ASR::expr_t* value2 = x.m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + if (ASR::is_a(*value2)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value2); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } else if (ASR::is_a(*value2)) { + ASR::Cast_t* s = ASR::down_cast(value2); + this->visit_Cast(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } + perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_div", + target, value1, value2); + break; } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPow: { + ASR::expr_t* value1 = x.m_args[0]; + ASR::expr_t* value2 = x.m_args[1]; + if (ASR::is_a(*value1)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value1); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } else if (ASR::is_a(*value1)) { + ASR::Cast_t* s = ASR::down_cast(value1); + this->visit_Cast(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_pow", - target, value1, value2); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiff: { - ASR::expr_t* value1 = x.m_args[0]; - ASR::expr_t* value2 = x.m_args[1]; - if (ASR::is_a(*value1)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value1); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); - } else if (ASR::is_a(*value1)) { - ASR::Cast_t* s = ASR::down_cast(value1); - this->visit_Cast(*s); - ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); - value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + if (ASR::is_a(*value2)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value2); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } else if (ASR::is_a(*value2)) { + ASR::Cast_t* s = ASR::down_cast(value2); + this->visit_Cast(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } + perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_pow", + target, value1, value2); + break; } + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiff: { + ASR::expr_t* value1 = x.m_args[0]; + ASR::expr_t* value2 = x.m_args[1]; + if (ASR::is_a(*value1)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value1); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } else if (ASR::is_a(*value1)) { + ASR::Cast_t* s = ASR::down_cast(value1); + this->visit_Cast(*s); + ASR::symbol_t* var_sym1 = current_scope->get_symbol(symengine_stack.pop()); + value1 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym1)); + } - if (ASR::is_a(*value2)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value2); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); - } else if (ASR::is_a(*value2)) { - ASR::Cast_t* s = ASR::down_cast(value2); - this->visit_Cast(*s); - ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); - value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + if (ASR::is_a(*value2)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value2); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } else if (ASR::is_a(*value2)) { + ASR::Cast_t* s = ASR::down_cast(value2); + this->visit_Cast(*s); + ASR::symbol_t* var_sym2 = current_scope->get_symbol(symengine_stack.pop()); + value2 = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym2)); + } + perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_diff", + target, value1, value2); + break; } - perform_symbolic_binary_operation(al, x.base.base.loc, module_scope, "basic_diff", - target, value1, value2); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSin: { - ASR::expr_t* value = x.m_args[0]; - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSin: { + ASR::expr_t* value = x.m_args[0]; + if (ASR::is_a(*value)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } else if (ASR::is_a(*value)) { + ASR::Cast_t* s = ASR::down_cast(value); + this->visit_Cast(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } + perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_sin", + target, value); + break; } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_sin", - target, value); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicCos: { - ASR::expr_t* value = x.m_args[0]; - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicCos: { + ASR::expr_t* value = x.m_args[0]; + if (ASR::is_a(*value)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } else if (ASR::is_a(*value)) { + ASR::Cast_t* s = ASR::down_cast(value); + this->visit_Cast(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } + perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_cos", + target, value); + break; } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_cos", - target, value); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicLog: { - ASR::expr_t* value = x.m_args[0]; - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicLog: { + ASR::expr_t* value = x.m_args[0]; + if (ASR::is_a(*value)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } else if (ASR::is_a(*value)) { + ASR::Cast_t* s = ASR::down_cast(value); + this->visit_Cast(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } + perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_log", + target, value); + break; } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_log", - target, value); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExp: { - ASR::expr_t* value = x.m_args[0]; - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExp: { + ASR::expr_t* value = x.m_args[0]; + if (ASR::is_a(*value)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } else if (ASR::is_a(*value)) { + ASR::Cast_t* s = ASR::down_cast(value); + this->visit_Cast(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } + perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_exp", + target, value); + break; } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_exp", - target, value); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicAbs: { - ASR::expr_t* value = x.m_args[0]; - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicAbs: { + ASR::expr_t* value = x.m_args[0]; + if (ASR::is_a(*value)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } else if (ASR::is_a(*value)) { + ASR::Cast_t* s = ASR::down_cast(value); + this->visit_Cast(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } + perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_abs", + target, value); + break; } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_abs", - target, value); - break; - } - case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExpand: { - ASR::expr_t* value = x.m_args[0]; - if (ASR::is_a(*value)) { - ASR::IntrinsicFunction_t *s = ASR::down_cast(value); - this->visit_IntrinsicFunction(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); - } else if (ASR::is_a(*value)) { - ASR::Cast_t* s = ASR::down_cast(value); - this->visit_Cast(*s); - ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); - value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicExpand: { + ASR::expr_t* value = x.m_args[0]; + if (ASR::is_a(*value)) { + ASR::IntrinsicFunction_t *s = ASR::down_cast(value); + this->visit_IntrinsicFunction(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } else if (ASR::is_a(*value)) { + ASR::Cast_t* s = ASR::down_cast(value); + this->visit_Cast(*s); + ASR::symbol_t* var_sym = current_scope->get_symbol(symengine_stack.pop()); + value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + } + perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_expand", + target, value); + break; + } + default: { + throw LCompilersException("IntrinsicFunction: `" + + ASRUtils::get_intrinsic_name(intrinsic_id) + + "` is not implemented"); } - perform_symbolic_unary_operation(al, x.base.base.loc, module_scope, "basic_expand", - target, value); - break; - } - default: { - throw LCompilersException("IntrinsicFunction: `" - + ASRUtils::get_intrinsic_name(intrinsic_id) - + "` is not implemented"); } } } From e98e0e8f698e5ab6afa50431d69a9ffaa7bb3ad9 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Tue, 8 Aug 2023 11:50:22 +0530 Subject: [PATCH 21/21] Fixed failing tests --- src/libasr/pass/replace_symbolic.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 2759e0e857..cf003ae577 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -1812,8 +1812,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*x.m_test)) return; ASR::SymbolicCompare_t *s = ASR::down_cast(x.m_test); SymbolTable* module_scope = current_scope->parent; - ASR::expr_t* left_tmp; - ASR::expr_t* right_tmp; + ASR::expr_t* left_tmp = nullptr; + ASR::expr_t* right_tmp = nullptr; std::string new_name = "basic_str"; symbolic_dependencies.push_back(new_name);