From 2d6bc11044670400ddb2bbbdb692f7145da4ee5f Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Sun, 2 Jul 2023 12:53:02 +0530 Subject: [PATCH 1/3] Added support for symbolic Expand & Differentiation --- src/libasr/codegen/asr_to_c_cpp.h | 20 ++++++ src/libasr/pass/intrinsic_function_registry.h | 67 +++++++++++++++++-- src/lpython/semantics/python_ast_to_asr.cpp | 12 +++- src/lpython/semantics/python_attribute_eval.h | 48 ++++++++++++- 4 files changed, 139 insertions(+), 8 deletions(-) diff --git a/src/libasr/codegen/asr_to_c_cpp.h b/src/libasr/codegen/asr_to_c_cpp.h index 98d4cccb94..6f845bd79e 100644 --- a/src/libasr/codegen/asr_to_c_cpp.h +++ b/src/libasr/codegen/asr_to_c_cpp.h @@ -2755,6 +2755,10 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) { src = performSymbolicOperation("basic_pow", x); return; } + case (static_cast(ASRUtils::IntrinsicFunctions::SymbolicDiff)): { + src = performSymbolicOperation("basic_diff", x); + return; + } case (static_cast(ASRUtils::IntrinsicFunctions::SymbolicPi)): { headers.insert("symengine/cwrapper.h"); LCOMPILERS_ASSERT(x.n_args == 0); @@ -2781,6 +2785,22 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) { src = target; return; } + case (static_cast(ASRUtils::IntrinsicFunctions::SymbolicExpand)): { + headers.insert("symengine/cwrapper.h"); + LCOMPILERS_ASSERT(x.n_args == 1); + std::string target = symengine_queue.push(); + std::string target_src = symengine_src; + this->visit_expr(*x.m_args[0]); + std::string arg1 = src; + std::string arg1_src = symengine_src; + if (ASR::is_a(*x.m_args[0])) { + symengine_queue.pop(); + } + symengine_src = target_src + arg1_src; + symengine_src += indent + "basic_expand(" + target + ", " + arg1 + ");\n"; + src = target; + return; + } default : { throw LCompilersException("IntrinsicFunction: `" + ASRUtils::get_intrinsic_name(x.m_intrinsic_id) diff --git a/src/libasr/pass/intrinsic_function_registry.h b/src/libasr/pass/intrinsic_function_registry.h index d82a8c253e..daa0ab8a45 100644 --- a/src/libasr/pass/intrinsic_function_registry.h +++ b/src/libasr/pass/intrinsic_function_registry.h @@ -72,6 +72,8 @@ enum class IntrinsicFunctions : int64_t { SymbolicPow, SymbolicPi, SymbolicInteger, + SymbolicDiff, + SymbolicExpand, Sum, // ... }; @@ -2056,7 +2058,7 @@ namespace SymbolicSymbol { } // namespace SymbolicSymbol -#define create_symbolic_binop_macro(X) \ +#define create_symbolic_binary_macro(X) \ namespace X{ \ \ static inline void verify_args(const ASR::IntrinsicFunction_t& x, \ @@ -2107,11 +2109,12 @@ namespace X{ } \ } // namespace X -create_symbolic_binop_macro(SymbolicAdd) -create_symbolic_binop_macro(SymbolicSub) -create_symbolic_binop_macro(SymbolicMul) -create_symbolic_binop_macro(SymbolicDiv) -create_symbolic_binop_macro(SymbolicPow) +create_symbolic_binary_macro(SymbolicAdd) +create_symbolic_binary_macro(SymbolicSub) +create_symbolic_binary_macro(SymbolicMul) +create_symbolic_binary_macro(SymbolicDiv) +create_symbolic_binary_macro(SymbolicPow) +create_symbolic_binary_macro(SymbolicDiff) namespace SymbolicPi { @@ -2166,6 +2169,46 @@ namespace SymbolicInteger { } } // namespace SymbolicInteger +namespace SymbolicExpand { + + static inline void verify_args(const ASR::IntrinsicFunction_t& x, diag::Diagnostics& diagnostics) { + const Location& loc = x.base.base.loc; + ASRUtils::require_impl(x.n_args == 1, + "SymbolicExpand must have exactly 1 input argument", + loc, diagnostics); + + ASR::ttype_t* input_type = ASRUtils::expr_type(x.m_args[0]); + ASRUtils::require_impl(ASR::is_a(*input_type), + "SymbolicExpand expects an argument of type SymbolicExpression", + x.base.base.loc, diagnostics); + } + + static inline ASR::expr_t *eval_SymbolicExpand(Allocator &/*al*/, + const Location &/*loc*/, Vec& /*args*/) { + // TODO + return nullptr; + } + + static inline ASR::asr_t* create_SymbolicExpand(Allocator& al, const Location& loc, + Vec& args, + const std::function err) { + if (args.size() != 1) { + err("Intrinsic expand function accepts exactly 1 argument", loc); + } + + ASR::ttype_t* argtype = ASRUtils::expr_type(args[0]); + if(!ASR::is_a(*argtype)) { + err("Argument of SymbolicExpand function must be of type SymbolicExpression", + args[0]->base.loc); + } + + ASR::ttype_t *to_type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, loc)); + return UnaryIntrinsicFunction::create_UnaryFunction(al, loc, args, eval_SymbolicExpand, + static_cast(ASRUtils::IntrinsicFunctions::SymbolicExpand), 0, to_type); + } + +} // namespace SymbolicExpand + namespace IntrinsicFunctionRegistry { static const std::map(ASRUtils::IntrinsicFunctions::SymbolicInteger), {nullptr, &SymbolicInteger::verify_args}}, + {static_cast(ASRUtils::IntrinsicFunctions::SymbolicDiff), + {nullptr, &SymbolicDiff::verify_args}}, + {static_cast(ASRUtils::IntrinsicFunctions::SymbolicExpand), + {nullptr, &SymbolicExpand::verify_args}}, }; static const std::map& intrinsic_function_id_to_name = { @@ -2282,6 +2329,10 @@ namespace IntrinsicFunctionRegistry { "pi"}, {static_cast(ASRUtils::IntrinsicFunctions::SymbolicInteger), "SymbolicInteger"}, + {static_cast(ASRUtils::IntrinsicFunctions::SymbolicDiff), + "SymbolicDiff"}, + {static_cast(ASRUtils::IntrinsicFunctions::SymbolicExpand), + "SymbolicExpand"}, {static_cast(ASRUtils::IntrinsicFunctions::Any), "any"}, {static_cast(ASRUtils::IntrinsicFunctions::Sum), @@ -2319,6 +2370,8 @@ namespace IntrinsicFunctionRegistry { {"SymbolicPow", {&SymbolicPow::create_SymbolicPow, &SymbolicPow::eval_SymbolicPow}}, {"pi", {&SymbolicPi::create_SymbolicPi, &SymbolicPi::eval_SymbolicPi}}, {"SymbolicInteger", {&SymbolicInteger::create_SymbolicInteger, &SymbolicInteger::eval_SymbolicInteger}}, + {"diff", {&SymbolicDiff::create_SymbolicDiff, &SymbolicDiff::eval_SymbolicDiff}}, + {"expand", {&SymbolicExpand::create_SymbolicExpand, &SymbolicExpand::eval_SymbolicExpand}}, }; static inline bool is_intrinsic_function(const std::string& name) { @@ -2433,6 +2486,8 @@ inline std::string get_intrinsic_name(int x) { INTRINSIC_NAME_CASE(SymbolicPow) INTRINSIC_NAME_CASE(SymbolicPi) INTRINSIC_NAME_CASE(SymbolicInteger) + INTRINSIC_NAME_CASE(SymbolicDiff) + INTRINSIC_NAME_CASE(SymbolicExpand) INTRINSIC_NAME_CASE(Sum) default : { throw LCompilersException("pickle: intrinsic_id not implemented"); diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index ecfe15b618..0805d6d176 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -694,6 +694,11 @@ class CommonVisitor : public AST::BaseVisitor { return; } + void handle_symbolic_attribute(ASR::expr_t *s, std::string attr_name, + const Location &loc, Vec &args) { + tmp = attr_handler.get_symbolic_attribute(s, attr_name, al, loc, args, diag); + return; + } void fill_expr_in_ttype_t(std::vector& exprs, ASR::dimension_t* dims, size_t n_dims) { for( size_t i = 0; i < n_dims; i++ ) { @@ -7113,6 +7118,11 @@ class BodyVisitor : public CommonVisitor { handle_string_attributes(se, args, at->m_attr, loc); return; } + ASR::ttype_t *type = ASRUtils::expr_type(se); + if (ASR::is_a(*type)) { + handle_symbolic_attribute(se, at->m_attr, loc, eles); + return; + } handle_builtin_attribute(se, at->m_attr, loc, eles); return; } @@ -7231,7 +7241,7 @@ class BodyVisitor : public CommonVisitor { if (!s) { std::set not_cpython_builtin = { - "sin", "cos", "gamma", "tan", "asin", "acos", "atan", "sinh", "cosh", "tanh", "exp", "exp2", "expm1", "Symbol", + "sin", "cos", "gamma", "tan", "asin", "acos", "atan", "sinh", "cosh", "tanh", "exp", "exp2", "expm1", "Symbol", "diff", "expand", "sum" // For sum called over lists }; if (ASRUtils::IntrinsicFunctionRegistry::is_intrinsic_function(call_name) && diff --git a/src/lpython/semantics/python_attribute_eval.h b/src/lpython/semantics/python_attribute_eval.h index 5ce7834bb1..fe4b01771e 100644 --- a/src/lpython/semantics/python_attribute_eval.h +++ b/src/lpython/semantics/python_attribute_eval.h @@ -15,7 +15,7 @@ struct AttributeHandler { typedef ASR::asr_t* (*attribute_eval_callback)(ASR::expr_t*, Allocator &, const Location &, Vec &, diag::Diagnostics &); - std::map attribute_map; + std::map attribute_map, symbolic_attribute_map; std::set modify_attr_set; AttributeHandler() { @@ -40,6 +40,11 @@ struct AttributeHandler { modify_attr_set = {"list@append", "list@remove", "list@reverse", "list@clear", "list@insert", "list@pop", "set@pop", "set@add", "set@remove", "dict@pop"}; + + symbolic_attribute_map = { + {"diff", &eval_symbolic_diff}, + {"expand", &eval_symbolic_expand} + }; } std::string get_type_name(ASR::ttype_t *t) { @@ -82,6 +87,19 @@ struct AttributeHandler { } } + ASR::asr_t* get_symbolic_attribute(ASR::expr_t *e, std::string attr_name, + Allocator &al, const Location &loc, Vec &args, diag::Diagnostics &diag) { + std::string key = attr_name; + auto search = symbolic_attribute_map.find(key); + if (search != symbolic_attribute_map.end()) { + attribute_eval_callback cb = search->second; + return cb(e, al, loc, args, diag); + } else { + throw SemanticError("S." + attr_name + " is not implemented yet", + loc); + } + } + static ASR::asr_t* eval_int_bit_length(ASR::expr_t *s, Allocator &al, const Location &loc, Vec &args, diag::Diagnostics &/*diag*/) { if (args.size() != 0) { @@ -388,6 +406,34 @@ struct AttributeHandler { return make_DictPop_t(al, loc, s, args[0], value_type, nullptr); } + static ASR::asr_t* eval_symbolic_diff(ASR::expr_t *s, Allocator &al, const Location &loc, + Vec &args, diag::Diagnostics &/*diag*/) { + Vec args_with_list; + args_with_list.reserve(al, args.size() + 1); + args_with_list.push_back(al, s); + for(size_t i = 0; i < args.size(); i++) { + args_with_list.push_back(al, args[i]); + } + ASRUtils::create_intrinsic_function create_function = + ASRUtils::IntrinsicFunctionRegistry::get_create_function("diff"); + return create_function(al, loc, args_with_list, [&](const std::string &msg, const Location &loc) + { throw SemanticError(msg, loc); }); + } + + static ASR::asr_t* eval_symbolic_expand(ASR::expr_t *s, Allocator &al, const Location &loc, + Vec &args, diag::Diagnostics &/*diag*/) { + Vec args_with_list; + args_with_list.reserve(al, args.size() + 1); + args_with_list.push_back(al, s); + for(size_t i = 0; i < args.size(); i++) { + args_with_list.push_back(al, args[i]); + } + ASRUtils::create_intrinsic_function create_function = + ASRUtils::IntrinsicFunctionRegistry::get_create_function("expand"); + return create_function(al, loc, args_with_list, [&](const std::string &msg, const Location &loc) + { throw SemanticError(msg, loc); }); + } + }; // AttributeHandler } // namespace LCompilers::LPython From 96666408d7d3ffff6fc041d3362c5d162645c5b2 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Sun, 2 Jul 2023 13:18:17 +0530 Subject: [PATCH 2/3] Added tests --- integration_tests/CMakeLists.txt | 1 + integration_tests/symbolics_05.py | 26 ++++++++++++++++++++++++++ 2 files changed, 27 insertions(+) create mode 100644 integration_tests/symbolics_05.py diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index d458f72e80..23175f5d47 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -602,6 +602,7 @@ RUN(NAME symbolics_01 LABELS cpython_sym c_sym) RUN(NAME symbolics_02 LABELS cpython_sym c_sym) RUN(NAME symbolics_03 LABELS cpython_sym c_sym) RUN(NAME symbolics_04 LABELS cpython_sym c_sym) +RUN(NAME symbolics_05 LABELS cpython_sym c_sym) RUN(NAME sizeof_01 LABELS llvm c EXTRAFILES sizeof_01b.c) diff --git a/integration_tests/symbolics_05.py b/integration_tests/symbolics_05.py new file mode 100644 index 0000000000..7876e55bc0 --- /dev/null +++ b/integration_tests/symbolics_05.py @@ -0,0 +1,26 @@ +from sympy import Symbol, expand, diff +from lpython import S + +def test_operations(): + x: S = Symbol('x') + y: S = Symbol('y') + z: S = Symbol('z') + + # test expand + a: S = (x + y)**S(2) + b: S = (x + y + z)**S(3) + assert(a.expand() == S(2)*x*y + x**S(2) + y**S(2)) + assert(expand(b) == S(3)*x*y**S(2) + S(3)*x*z**S(2) + S(3)*x**S(2)*y + S(3)*x**S(2)*z +\ + S(3)*y*z**S(2) + S(3)*y**S(2)*z + S(6)*x*y*z + x**S(3) + y**S(3) + z**S(3)) + print(a.expand()) + print(expand(b)) + + # test diff + c: S = (x + y)**S(2) + d: S = (x + y + z)**S(3) + assert(c.diff(x) == S(2)*(x + y)) + assert(diff(d, x) == S(3)*(x + y + z)**S(2)) + print(c.diff(x)) + print(diff(d, x)) + +test_operations() \ No newline at end of file From 315e020555ede176f9a25c56d5d2c39236d10688 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Sun, 2 Jul 2023 13:23:02 +0530 Subject: [PATCH 3/3] Refactored test file --- integration_tests/symbolics_05.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/integration_tests/symbolics_05.py b/integration_tests/symbolics_05.py index 7876e55bc0..9214298d93 100644 --- a/integration_tests/symbolics_05.py +++ b/integration_tests/symbolics_05.py @@ -5,10 +5,10 @@ def test_operations(): x: S = Symbol('x') y: S = Symbol('y') z: S = Symbol('z') - - # test expand a: S = (x + y)**S(2) b: S = (x + y + z)**S(3) + + # test expand assert(a.expand() == S(2)*x*y + x**S(2) + y**S(2)) assert(expand(b) == S(3)*x*y**S(2) + S(3)*x*z**S(2) + S(3)*x**S(2)*y + S(3)*x**S(2)*z +\ S(3)*y*z**S(2) + S(3)*y**S(2)*z + S(6)*x*y*z + x**S(3) + y**S(3) + z**S(3)) @@ -16,11 +16,9 @@ def test_operations(): print(expand(b)) # test diff - c: S = (x + y)**S(2) - d: S = (x + y + z)**S(3) - assert(c.diff(x) == S(2)*(x + y)) - assert(diff(d, x) == S(3)*(x + y + z)**S(2)) - print(c.diff(x)) - print(diff(d, x)) + assert(a.diff(x) == S(2)*(x + y)) + assert(diff(b, x) == S(3)*(x + y + z)**S(2)) + print(a.diff(x)) + print(diff(b, x)) test_operations() \ No newline at end of file