Skip to content

Commit c210323

Browse files
authored
Added support for assert through SymbolicCompare (#2057)
1 parent a3c5220 commit c210323

File tree

8 files changed

+72
-18
lines changed

8 files changed

+72
-18
lines changed

integration_tests/symbolics_01.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,7 @@ def main0():
77
x = pi
88
z: S = x + y
99
print(z)
10+
assert(z == pi + y)
11+
assert(z != S(2)*pi + y)
1012

1113
main0()

integration_tests/symbolics_02.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,28 +7,34 @@ def test_symbolic_operations():
77

88
# Addition
99
z: S = x + y
10-
print(z) # Expected: x + y
10+
assert(z == x + y)
11+
print(z)
1112

1213
# Subtraction
1314
w: S = x - y
14-
print(w) # Expected: x - y
15+
assert(w == x - y)
16+
print(w)
1517

1618
# Multiplication
1719
u: S = x * y
18-
print(u) # Expected: x*y
19-
20+
assert(u == x * y)
21+
print(u)
22+
2023
# Division
2124
v: S = x / y
22-
print(v) # Expected: x/y
23-
25+
assert(v == x / y)
26+
print(v)
27+
2428
# Power
2529
p: S = x ** y
26-
print(p) # Expected: x**y
27-
30+
assert(p == x ** y)
31+
print(p)
32+
2833
# Casting
2934
a: S = S(100)
3035
b: S = S(-100)
3136
c: S = a + b
32-
print(c) # Expected: 0
37+
assert(c == S(0))
38+
print(c)
3339

3440
test_symbolic_operations()

integration_tests/symbolics_03.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,16 @@ def test_operator_chaining():
66
x: S = Symbol('x')
77
y: S = Symbol('y')
88
z: S = Symbol('z')
9-
Pi: S = Symbol('pi')
109

1110
a: S = x * w
12-
b: S = a + Pi
11+
b: S = a + pi
1312
c: S = b / z
1413
d: S = c ** w
1514

15+
assert(a == S(2)*x)
16+
assert(b == pi + S(2)*x)
17+
assert(c == (pi + S(2)*x)/z)
18+
assert(d == (pi + S(2)*x)**S(2)/z**S(2))
1619
print(a) # Expected: 2*x
1720
print(b) # Expected: pi + 2*x
1821
print(c) # Expected: (pi + 2*x)/z

integration_tests/symbolics_04.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ def test_chained_operations():
1313
result: S = (w ** S(2) - pi) + S(3)
1414

1515
# Print Statements
16+
assert(result == S(3) + (a -b)**S(2)*(x + y)**S(2)/(z + pi)**S(2) - pi)
1617
print(result)
17-
# Expected: 3 + (a - b)**2*(x + y)**2/(z + pi)**2 - pi
1818

1919
# Additional Variables
2020
c: S = Symbol('c')
@@ -29,13 +29,13 @@ def test_chained_operations():
2929
result = (z + e) * (a - b)
3030

3131
# Print Statements
32+
assert(result == (a - b)*(e + ((S(5) + pi)*(S(-10) + (e + c*d)/f))**(S(2)/(d + f))))
3233
print(result)
33-
# Expected: (a - b)*(e + ((5 + pi)*(-10 + (e + c*d)/f))**(2/(d + f)))
34+
assert(x == (e + c*d)/f)
3435
print(x)
35-
# Expected: (e + c*d)/f
36+
assert(y == (S(5) + pi)*(S(-10) + (e + c*d)/f))
3637
print(y)
37-
# Expected: (5 + pi)*(-10 + (e + c*d)/f)
38+
assert(z == ((S(5) + pi)*(S(-10) + (e + c*d)/f))**(S(2)/(d + f)))
3839
print(z)
39-
# Expected: ((5 + pi)*(-10 + (e + c*d)/f))**(2/(d + f))
4040

4141
test_chained_operations()

src/libasr/ASR.asdl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ expr
290290
| StringChr(expr arg, ttype type, expr? value)
291291

292292
| CPtrCompare(expr left, cmpop op, expr right, ttype type, expr? value)
293+
| SymbolicCompare(expr left, cmpop op, expr right, ttype type, expr? value)
293294

294295
| DictConstant(expr* keys, expr* values, ttype type)
295296
| DictLen(expr arg, ttype type, expr? value)

src/libasr/codegen/asr_to_c.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,15 +1064,19 @@ R"( // Initialise Numpy
10641064
std::string indent(indentation_level*indentation_spaces, ' ');
10651065
std::string out = indent;
10661066
bracket_open++;
1067+
visit_expr(*x.m_test);
1068+
if (ASR::is_a<ASR::SymbolicCompare_t>(*x.m_test)){
1069+
out = symengine_src;
1070+
symengine_src = "";
1071+
out += indent;
1072+
}
10671073
if (x.m_msg) {
10681074
out += "ASSERT_MSG(";
1069-
visit_expr(*x.m_test);
10701075
out += src + ", ";
10711076
visit_expr(*x.m_msg);
10721077
out += src + ");\n";
10731078
} else {
10741079
out += "ASSERT(";
1075-
visit_expr(*x.m_test);
10761080
out += src + ");\n";
10771081
}
10781082
bracket_open--;

src/libasr/codegen/asr_to_c_cpp.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <iostream>
1515
#include <memory>
1616
#include <set>
17+
#include <unordered_set>
1718

1819
#include <libasr/asr.h>
1920
#include <libasr/containers.h>
@@ -97,6 +98,7 @@ class SymEngineQueue {
9798
if(queue_front == -1 || queue_front >= static_cast<int>(queue.size())) {
9899
var = "queue" + std::to_string(queue.size());
99100
queue.push_back(var);
101+
if(queue_front == -1) queue_front++;
100102
symengine_src = indent + "basic " + var + ";\n";
101103
symengine_src += indent + "basic_new_stack(" + var + ");\n";
102104
}
@@ -1952,6 +1954,40 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
19521954
handle_Compare(x);
19531955
}
19541956

1957+
void visit_SymbolicCompare(const ASR::SymbolicCompare_t &x) {
1958+
CHECK_FAST_C_CPP(compiler_options, x)
1959+
self().visit_expr(*x.m_left);
1960+
std::string left_src = symengine_src;
1961+
if(ASR::is_a<ASR::Var_t>(*x.m_left)){
1962+
symengine_queue.pop();
1963+
}
1964+
std::string left = std::move(src);
1965+
1966+
self().visit_expr(*x.m_right);
1967+
std::string right_src = symengine_src;
1968+
if(ASR::is_a<ASR::Var_t>(*x.m_right)){
1969+
symengine_queue.pop();
1970+
}
1971+
std::string right = std::move(src);
1972+
std::string op_str = ASRUtils::cmpop_to_str(x.m_op);
1973+
switch (x.m_op) {
1974+
case (ASR::cmpopType::Eq) : {
1975+
src = "basic_eq(" + left + ", " + right + ") " + op_str + " 1";
1976+
break;
1977+
}
1978+
case (ASR::cmpopType::NotEq) : {
1979+
src = "basic_neq(" + left + ", " + right + ") " + op_str + " 0";
1980+
break;
1981+
}
1982+
default : {
1983+
throw LCompilersException("Symbolic comparison operator: '"
1984+
+ op_str
1985+
+ "' is not implemented");
1986+
}
1987+
}
1988+
symengine_src = left_src + right_src;
1989+
}
1990+
19551991
template<typename T>
19561992
void handle_Compare(const T &x) {
19571993
CHECK_FAST_C_CPP(compiler_options, x)

src/lpython/semantics/python_ast_to_asr.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6221,6 +6221,8 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
62216221
x.base.base.loc);
62226222
}
62236223
tmp = ASR::make_CPtrCompare_t(al, x.base.base.loc, left, asr_op, right, type, value);
6224+
} else if (ASR::is_a<ASR::SymbolicExpression_t>(*dest_type)) {
6225+
tmp = ASR::make_SymbolicCompare_t(al, x.base.base.loc, left, asr_op, right, type, value);
62246226
} else {
62256227
throw SemanticError("Compare not supported for type: " + ASRUtils::type_to_str_python(dest_type),
62266228
x.base.base.loc);

0 commit comments

Comments
 (0)