diff --git a/integration_tests/symbolics_01.py b/integration_tests/symbolics_01.py index 9c1a0e9096..525cf6ab4c 100644 --- a/integration_tests/symbolics_01.py +++ b/integration_tests/symbolics_01.py @@ -7,5 +7,7 @@ def main0(): x = pi z: S = x + y print(z) + assert(z == pi + y) + assert(z != S(2)*pi + y) main0() \ No newline at end of file diff --git a/integration_tests/symbolics_02.py b/integration_tests/symbolics_02.py index 2bc24e8511..f22a432606 100644 --- a/integration_tests/symbolics_02.py +++ b/integration_tests/symbolics_02.py @@ -7,28 +7,34 @@ def test_symbolic_operations(): # Addition z: S = x + y - print(z) # Expected: x + y + assert(z == x + y) + print(z) # Subtraction w: S = x - y - print(w) # Expected: x - y + assert(w == x - y) + print(w) # Multiplication u: S = x * y - print(u) # Expected: x*y - + assert(u == x * y) + print(u) + # Division v: S = x / y - print(v) # Expected: x/y - + assert(v == x / y) + print(v) + # Power p: S = x ** y - print(p) # Expected: x**y - + assert(p == x ** y) + print(p) + # Casting a: S = S(100) b: S = S(-100) c: S = a + b - print(c) # Expected: 0 + assert(c == S(0)) + print(c) test_symbolic_operations() diff --git a/integration_tests/symbolics_03.py b/integration_tests/symbolics_03.py index 12295fd0e7..8dc91a9720 100644 --- a/integration_tests/symbolics_03.py +++ b/integration_tests/symbolics_03.py @@ -6,13 +6,16 @@ def test_operator_chaining(): x: S = Symbol('x') y: S = Symbol('y') z: S = Symbol('z') - Pi: S = Symbol('pi') a: S = x * w - b: S = a + Pi + b: S = a + pi c: S = b / z d: S = c ** w + assert(a == S(2)*x) + assert(b == pi + S(2)*x) + assert(c == (pi + S(2)*x)/z) + assert(d == (pi + S(2)*x)**S(2)/z**S(2)) print(a) # Expected: 2*x print(b) # Expected: pi + 2*x print(c) # Expected: (pi + 2*x)/z diff --git a/integration_tests/symbolics_04.py b/integration_tests/symbolics_04.py index 32cb8df66c..63d30bf3f6 100644 --- a/integration_tests/symbolics_04.py +++ b/integration_tests/symbolics_04.py @@ -13,8 +13,8 @@ def test_chained_operations(): result: S = (w ** S(2) - pi) + S(3) # Print Statements + assert(result == S(3) + (a -b)**S(2)*(x + y)**S(2)/(z + pi)**S(2) - pi) print(result) - # Expected: 3 + (a - b)**2*(x + y)**2/(z + pi)**2 - pi # Additional Variables c: S = Symbol('c') @@ -29,13 +29,13 @@ def test_chained_operations(): result = (z + e) * (a - b) # Print Statements + assert(result == (a - b)*(e + ((S(5) + pi)*(S(-10) + (e + c*d)/f))**(S(2)/(d + f)))) print(result) - # Expected: (a - b)*(e + ((5 + pi)*(-10 + (e + c*d)/f))**(2/(d + f))) + assert(x == (e + c*d)/f) print(x) - # Expected: (e + c*d)/f + assert(y == (S(5) + pi)*(S(-10) + (e + c*d)/f)) print(y) - # Expected: (5 + pi)*(-10 + (e + c*d)/f) + assert(z == ((S(5) + pi)*(S(-10) + (e + c*d)/f))**(S(2)/(d + f))) print(z) - # Expected: ((5 + pi)*(-10 + (e + c*d)/f))**(2/(d + f)) test_chained_operations() \ No newline at end of file diff --git a/src/libasr/ASR.asdl b/src/libasr/ASR.asdl index 6d92af736d..e4d54f237c 100644 --- a/src/libasr/ASR.asdl +++ b/src/libasr/ASR.asdl @@ -290,6 +290,7 @@ expr | StringChr(expr arg, ttype type, expr? value) | CPtrCompare(expr left, cmpop op, expr right, ttype type, expr? value) + | SymbolicCompare(expr left, cmpop op, expr right, ttype type, expr? value) | DictConstant(expr* keys, expr* values, ttype type) | DictLen(expr arg, ttype type, expr? value) diff --git a/src/libasr/codegen/asr_to_c.cpp b/src/libasr/codegen/asr_to_c.cpp index 1f649059b6..45fa2534cc 100644 --- a/src/libasr/codegen/asr_to_c.cpp +++ b/src/libasr/codegen/asr_to_c.cpp @@ -1064,15 +1064,19 @@ R"( // Initialise Numpy std::string indent(indentation_level*indentation_spaces, ' '); std::string out = indent; bracket_open++; + visit_expr(*x.m_test); + if (ASR::is_a(*x.m_test)){ + out = symengine_src; + symengine_src = ""; + out += indent; + } if (x.m_msg) { out += "ASSERT_MSG("; - visit_expr(*x.m_test); out += src + ", "; visit_expr(*x.m_msg); out += src + ");\n"; } else { out += "ASSERT("; - visit_expr(*x.m_test); out += src + ");\n"; } bracket_open--; diff --git a/src/libasr/codegen/asr_to_c_cpp.h b/src/libasr/codegen/asr_to_c_cpp.h index e9610667f4..50268a9148 100644 --- a/src/libasr/codegen/asr_to_c_cpp.h +++ b/src/libasr/codegen/asr_to_c_cpp.h @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -97,6 +98,7 @@ class SymEngineQueue { if(queue_front == -1 || queue_front >= static_cast(queue.size())) { var = "queue" + std::to_string(queue.size()); queue.push_back(var); + if(queue_front == -1) queue_front++; symengine_src = indent + "basic " + var + ";\n"; symengine_src += indent + "basic_new_stack(" + var + ");\n"; } @@ -1952,6 +1954,40 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) { handle_Compare(x); } + void visit_SymbolicCompare(const ASR::SymbolicCompare_t &x) { + CHECK_FAST_C_CPP(compiler_options, x) + self().visit_expr(*x.m_left); + std::string left_src = symengine_src; + if(ASR::is_a(*x.m_left)){ + symengine_queue.pop(); + } + std::string left = std::move(src); + + self().visit_expr(*x.m_right); + std::string right_src = symengine_src; + if(ASR::is_a(*x.m_right)){ + symengine_queue.pop(); + } + std::string right = std::move(src); + std::string op_str = ASRUtils::cmpop_to_str(x.m_op); + switch (x.m_op) { + case (ASR::cmpopType::Eq) : { + src = "basic_eq(" + left + ", " + right + ") " + op_str + " 1"; + break; + } + case (ASR::cmpopType::NotEq) : { + src = "basic_neq(" + left + ", " + right + ") " + op_str + " 0"; + break; + } + default : { + throw LCompilersException("Symbolic comparison operator: '" + + op_str + + "' is not implemented"); + } + } + symengine_src = left_src + right_src; + } + template void handle_Compare(const T &x) { CHECK_FAST_C_CPP(compiler_options, x) diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 5952ffcd2f..0d7b3d52e9 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -6227,6 +6227,8 @@ class BodyVisitor : public CommonVisitor { x.base.base.loc); } tmp = ASR::make_CPtrCompare_t(al, x.base.base.loc, left, asr_op, right, type, value); + } else if (ASR::is_a(*dest_type)) { + tmp = ASR::make_SymbolicCompare_t(al, x.base.base.loc, left, asr_op, right, type, value); } else { throw SemanticError("Compare not supported for type: " + ASRUtils::type_to_str_python(dest_type), x.base.base.loc);