Skip to content

Commit dc3510b

Browse files
authored
Added support for handling chaining of Symbolic operators and printing without assignment . (#2051)
1 parent 180ba71 commit dc3510b

File tree

5 files changed

+176
-96
lines changed

5 files changed

+176
-96
lines changed

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -593,6 +593,7 @@ RUN(NAME structs_30 LABELS cpython llvm c)
593593
RUN(NAME symbolics_01 LABELS cpython_sym c_sym)
594594
RUN(NAME symbolics_02 LABELS cpython_sym c_sym)
595595
RUN(NAME symbolics_03 LABELS cpython_sym c_sym)
596+
RUN(NAME symbolics_04 LABELS cpython_sym c_sym)
596597

597598
RUN(NAME sizeof_01 LABELS llvm c
598599
EXTRAFILES sizeof_01b.c)

integration_tests/symbolics_04.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from sympy import Symbol, pi, S
2+
from lpython import S
3+
4+
def test_chained_operations():
5+
x: S = Symbol('x')
6+
y: S = Symbol('y')
7+
z: S = Symbol('z')
8+
a: S = Symbol('a')
9+
b: S = Symbol('b')
10+
11+
# Chained Operations
12+
w: S = (x + y) * ((a - b) / (pi + z))
13+
result: S = (w ** S(2) - pi) + S(3)
14+
15+
# Print Statements
16+
print(result)
17+
# Expected: 3 + (a - b)**2*(x + y)**2/(z + pi)**2 - pi
18+
19+
# Additional Variables
20+
c: S = Symbol('c')
21+
d: S = Symbol('d')
22+
e: S = Symbol('e')
23+
f: S = Symbol('f')
24+
25+
# Chained Operations with Additional Variables
26+
x = (c * d + e) / f
27+
y = (x - S(10)) * (pi + S(5))
28+
z = y ** (S(2) / (f + d))
29+
result = (z + e) * (a - b)
30+
31+
# Print Statements
32+
print(result)
33+
# Expected: (a - b)*(e + ((5 + pi)*(-10 + (e + c*d)/f))**(2/(d + f)))
34+
print(x)
35+
# Expected: (e + c*d)/f
36+
print(y)
37+
# Expected: (5 + pi)*(-10 + (e + c*d)/f)
38+
print(z)
39+
# Expected: ((5 + pi)*(-10 + (e + c*d)/f))**(2/(d + f))
40+
41+
test_chained_operations()

src/libasr/codegen/asr_to_c.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,8 +1138,12 @@ R"( // Initialise Numpy
11381138
if( ASRUtils::is_array(value_type) ) {
11391139
src += "->data";
11401140
}
1141-
if (value_type->type == ASR::ttypeType::List ||
1142-
value_type->type == ASR::ttypeType::Tuple) {
1141+
if(ASR::is_a<ASR::SymbolicExpression_t>(*value_type)) {
1142+
out += symengine_src;
1143+
symengine_src = "";
1144+
}
1145+
if( ASR::is_a<ASR::List_t>(*value_type) ||
1146+
ASR::is_a<ASR::Tuple_t>(*value_type)) {
11431147
tmp_gen += "\"";
11441148
if (!v.empty()) {
11451149
for (auto &s: v) {
@@ -1156,13 +1160,16 @@ R"( // Initialise Numpy
11561160
}
11571161
tmp_gen += c_ds_api->get_print_type(value_type, ASR::is_a<ASR::ArrayItem_t>(*x.m_values[i]));
11581162
v.push_back(src);
1159-
if (value_type->type == ASR::ttypeType::Complex) {
1163+
if (ASR::is_a<ASR::Complex_t>(*value_type)) {
11601164
v.pop_back();
11611165
v.push_back("creal(" + src + ")");
11621166
v.push_back("cimag(" + src + ")");
1163-
} else if(value_type->type == ASR::ttypeType::SymbolicExpression){
1167+
} else if(ASR::is_a<ASR::SymbolicExpression_t>(*value_type)){
11641168
v.pop_back();
11651169
v.push_back("basic_str(" + src + ")");
1170+
if(ASR::is_a<ASR::Var_t>(*x.m_values[i])) {
1171+
symengine_queue.pop();
1172+
}
11661173
}
11671174
if (i+1!=x.n_values) {
11681175
tmp_gen += "\%s";

src/libasr/codegen/asr_to_c_cpp.h

Lines changed: 123 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,32 @@ struct CPPDeclarationOptions: public DeclarationOptions {
8383
}
8484
};
8585

86+
class SymEngineQueue {
87+
public:
88+
std::vector<std::string> queue;
89+
int queue_front = -1;
90+
std::string& symengine_src;
91+
92+
SymEngineQueue(std::string& symengine_src) : symengine_src(symengine_src) {}
93+
94+
std::string push() {
95+
std::string indent(4, ' ');
96+
std::string var;
97+
if(queue_front == -1 || queue_front >= static_cast<int>(queue.size())) {
98+
var = "queue" + std::to_string(queue.size());
99+
queue.push_back(var);
100+
symengine_src = indent + "basic " + var + ";\n";
101+
symengine_src += indent + "basic_new_stack(" + var + ");\n";
102+
}
103+
return queue[queue_front++];
104+
}
105+
106+
void pop() {
107+
LCOMPILERS_ASSERT(queue_front != -1 && queue_front < static_cast<int>(queue.size()));
108+
queue_front++;
109+
}
110+
};
111+
86112
template <class Struct>
87113
class BaseCCPPVisitor : public ASR::BaseVisitor<Struct>
88114
{
@@ -115,6 +141,8 @@ class BaseCCPPVisitor : public ASR::BaseVisitor<Struct>
115141
bool is_c;
116142
std::set<std::string> headers, user_headers, user_defines;
117143
std::vector<std::string> tmp_buffer_src;
144+
std::string symengine_src;
145+
SymEngineQueue symengine_queue{symengine_src};
118146

119147
SymbolTable* global_scope;
120148
int64_t lower_bound;
@@ -1178,6 +1206,17 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
11781206
target = "&" + target;
11791207
}
11801208
}
1209+
if( ASR::is_a<ASR::SymbolicExpression_t>(*value_type) ) {
1210+
if(ASR::is_a<ASR::Var_t>(*x.m_value)){
1211+
src = indent + "basic_assign(" + target + ", " + value + ");\n";
1212+
symengine_queue.pop();
1213+
symengine_queue.pop();
1214+
return;
1215+
}
1216+
src = symengine_src;
1217+
symengine_src = "";
1218+
return;
1219+
}
11811220
if( !from_std_vector_helper.empty() ) {
11821221
src = from_std_vector_helper;
11831222
} else {
@@ -1243,12 +1282,7 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
12431282
src += alloc + indent + c_ds_api->get_deepcopy(m_target_type, value, target) + "\n";
12441283
}
12451284
} else {
1246-
if (m_target_type->type == ASR::ttypeType::SymbolicExpression){
1247-
ASR::expr_t* m_value_expr = x.m_value;
1248-
src += alloc + indent + c_ds_api->get_deepcopy_symbolic(m_value_expr, value, target) + "\n";
1249-
} else {
1250-
src += alloc + indent + c_ds_api->get_deepcopy(m_target_type, value, target) + "\n";
1251-
}
1285+
src += alloc + indent + c_ds_api->get_deepcopy(m_target_type, value, target) + "\n";
12521286
}
12531287
} else {
12541288
src += indent + c_ds_api->get_deepcopy(m_target_type, value, target) + "\n";
@@ -1646,6 +1680,15 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
16461680
src = std::string(ASR::down_cast<ASR::Variable_t>(s)->m_name);
16471681
}
16481682
last_expr_precedence = 2;
1683+
ASR::ttype_t* var_type = sv->m_type;
1684+
if( ASR::is_a<ASR::SymbolicExpression_t>(*var_type)) {
1685+
std::string var_name = std::string(ASR::down_cast<ASR::Variable_t>(s)->m_name);
1686+
symengine_queue.queue.push_back(var_name);
1687+
if (symengine_queue.queue_front == -1) {
1688+
symengine_queue.queue_front = 0;
1689+
}
1690+
symengine_src = "";
1691+
}
16491692
}
16501693

16511694
void visit_StructInstanceMember(const ASR::StructInstanceMember_t& x) {
@@ -1858,6 +1901,8 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
18581901
break;
18591902
}
18601903
case (ASR::cast_kindType::IntegerToSymbolicExpression): {
1904+
self().visit_expr(*x.m_value);
1905+
last_expr_precedence = 2;
18611906
break;
18621907
}
18631908
default : throw CodeGenError("Cast kind " + std::to_string(x.m_kind) + " not implemented",
@@ -2591,8 +2636,34 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
25912636
out += func_name; break; \
25922637
}
25932638

2639+
std::string performSymbolicOperation(const std::string& functionName, const ASR::IntrinsicFunction_t& x) {
2640+
headers.insert("symengine/cwrapper.h");
2641+
std::string indent(4, ' ');
2642+
LCOMPILERS_ASSERT(x.n_args == 2);
2643+
std::string target = symengine_queue.push();
2644+
std::string target_src = symengine_src;
2645+
this->visit_expr(*x.m_args[0]);
2646+
std::string arg1 = src;
2647+
std::string arg1_src = symengine_src;
2648+
// Check if x.m_args[0] is a Var
2649+
if (ASR::is_a<ASR::Var_t>(*x.m_args[0])) {
2650+
symengine_queue.pop();
2651+
}
2652+
this->visit_expr(*x.m_args[1]);
2653+
std::string arg2 = src;
2654+
std::string arg2_src = symengine_src;
2655+
// Check if x.m_args[0] is a Var
2656+
if (ASR::is_a<ASR::Var_t>(*x.m_args[1])) {
2657+
symengine_queue.pop();
2658+
}
2659+
symengine_src = target_src + arg1_src + arg2_src;
2660+
symengine_src += indent + functionName + "(" + target + ", " + arg1 + ", " + arg2 + ");\n";
2661+
return target;
2662+
}
2663+
25942664
void visit_IntrinsicFunction(const ASR::IntrinsicFunction_t &x) {
25952665
std::string out;
2666+
std::string indent(4, ' ');
25962667
switch (x.m_intrinsic_id) {
25972668
SET_INTRINSIC_NAME(Sin, "sin");
25982669
SET_INTRINSIC_NAME(Cos, "cos");
@@ -2607,22 +2678,51 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
26072678
SET_INTRINSIC_NAME(Exp, "exp");
26082679
SET_INTRINSIC_NAME(Exp2, "exp2");
26092680
SET_INTRINSIC_NAME(Expm1, "expm1");
2610-
SET_INTRINSIC_NAME(SymbolicSymbol, "Symbol");
2611-
SET_INTRINSIC_NAME(SymbolicInteger, "Integer");
2612-
SET_INTRINSIC_NAME(SymbolicPi, "pi");
2613-
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicAdd)):
2614-
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicSub)):
2615-
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicMul)):
2616-
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicDiv)):
2681+
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicAdd)): {
2682+
src = performSymbolicOperation("basic_add", x);
2683+
return;
2684+
}
2685+
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicSub)): {
2686+
src = performSymbolicOperation("basic_sub", x);
2687+
return;
2688+
}
2689+
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicMul)): {
2690+
src = performSymbolicOperation("basic_mul", x);
2691+
return;
2692+
}
2693+
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicDiv)): {
2694+
src = performSymbolicOperation("basic_div", x);
2695+
return;
2696+
}
26172697
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicPow)): {
2618-
LCOMPILERS_ASSERT(x.n_args == 2);
2698+
src = performSymbolicOperation("basic_pow", x);
2699+
return;
2700+
}
2701+
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicPi)): {
2702+
headers.insert("symengine/cwrapper.h");
2703+
LCOMPILERS_ASSERT(x.n_args == 0);
2704+
std::string target = symengine_queue.push();
2705+
symengine_src += indent + "basic_const_pi(" + target + ");\n";
2706+
src = target;
2707+
return;
2708+
}
2709+
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicSymbol)): {
2710+
headers.insert("symengine/cwrapper.h");
2711+
LCOMPILERS_ASSERT(x.n_args == 1);
26192712
this->visit_expr(*x.m_args[0]);
2620-
std::string arg1 = src;
2621-
this->visit_expr(*x.m_args[1]);
2622-
std::string arg2 = src;
2623-
out = arg1 + "," + arg2;
2624-
src = out;
2625-
break;
2713+
std::string target = symengine_queue.push();
2714+
symengine_src += indent + "symbol_set(" + target + ", " + src + ");\n";
2715+
src = target;
2716+
return;
2717+
}
2718+
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicInteger)): {
2719+
headers.insert("symengine/cwrapper.h");
2720+
LCOMPILERS_ASSERT(x.n_args == 1);
2721+
this->visit_expr(*x.m_args[0]);
2722+
std::string target = symengine_queue.push();
2723+
symengine_src += indent + "integer_set_si(" + target + ", " + src + ");\n";
2724+
src = target;
2725+
return;
26262726
}
26272727
default : {
26282728
throw LCompilersException("IntrinsicFunction: `"
@@ -2631,16 +2731,9 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
26312731
}
26322732
}
26332733
headers.insert("math.h");
2634-
if (x.n_args == 0){
2635-
src = out;
2636-
} else if (x.n_args == 1) {
2637-
this->visit_expr(*x.m_args[0]);
2638-
if ((x.m_intrinsic_id != static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicSymbol)) &&
2639-
(x.m_intrinsic_id != static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicInteger))) {
2640-
out += "(" + src + ")";
2641-
src = out;
2642-
}
2643-
}
2734+
this->visit_expr(*x.m_args[0]);
2735+
out += "(" + src + ")";
2736+
src = out;
26442737
}
26452738
};
26462739

src/libasr/codegen/c_utils.h

Lines changed: 0 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -445,68 +445,6 @@ class CCPPDSUtils {
445445
return result;
446446
}
447447

448-
std::string generate_binary_operator_code(std::string value, std::string target, std::string operatorName) {
449-
size_t delimiterPos = value.find(",");
450-
std::string leftPart = value.substr(0, delimiterPos);
451-
std::string rightPart = value.substr(delimiterPos + 1);
452-
std::string result = operatorName + "(" + target + ", " + leftPart + ", " + rightPart + ");";
453-
return result;
454-
}
455-
456-
std::string get_deepcopy_symbolic(ASR::expr_t *value_expr, std::string value, std::string target) {
457-
std::string result;
458-
if (ASR::is_a<ASR::Var_t>(*value_expr)) {
459-
result = "basic_assign(" + target + ", " + value + ");";
460-
} else if (ASR::is_a<ASR::IntrinsicFunction_t>(*value_expr)) {
461-
ASR::IntrinsicFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicFunction_t>(value_expr);
462-
int64_t intrinsic_id = intrinsic_func->m_intrinsic_id;
463-
switch (static_cast<LCompilers::ASRUtils::IntrinsicFunctions>(intrinsic_id)) {
464-
case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSymbol: {
465-
result = "symbol_set(" + target + ", " + value + ");";
466-
break;
467-
}
468-
case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicAdd: {
469-
result = generate_binary_operator_code(value, target, "basic_add");
470-
break;
471-
}
472-
case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSub: {
473-
result = generate_binary_operator_code(value, target, "basic_sub");
474-
break;
475-
}
476-
case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicMul: {
477-
result = generate_binary_operator_code(value, target, "basic_mul");
478-
break;
479-
}
480-
case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiv: {
481-
result = generate_binary_operator_code(value, target, "basic_div");
482-
break;
483-
}
484-
case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPow: {
485-
result = generate_binary_operator_code(value, target, "basic_pow");
486-
break;
487-
}
488-
case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPi: {
489-
result = "basic_const_pi(" + target + ");";
490-
break;
491-
}
492-
case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicInteger: {
493-
result = "integer_set_si(" + target + ", " + value + ");";
494-
break;
495-
}
496-
default: {
497-
throw LCompilersException("IntrinsicFunction: `"
498-
+ LCompilers::ASRUtils::get_intrinsic_name(intrinsic_id)
499-
+ "` is not implemented");
500-
}
501-
}
502-
} else if (ASR::is_a<ASR::Cast_t>(*value_expr)) {
503-
ASR::Cast_t* cast_expr = ASR::down_cast<ASR::Cast_t>(value_expr);
504-
std::string cast_value_expr = get_deepcopy_symbolic(cast_expr->m_value, value, target);
505-
return cast_value_expr;
506-
}
507-
return result;
508-
}
509-
510448
std::string get_type(ASR::ttype_t *t) {
511449
LCOMPILERS_ASSERT(CUtils::is_non_primitive_DT(t));
512450
if (ASR::is_a<ASR::List_t>(*t)) {

0 commit comments

Comments
 (0)