Skip to content

Commit 39e1a26

Browse files
authored
Merge pull request #2094 from anutosh491/GSoC_PR6
Added support for symbolic elementary functions
2 parents 5334944 + 486db54 commit 39e1a26

File tree

5 files changed

+176
-63
lines changed

5 files changed

+176
-63
lines changed

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,6 +605,7 @@ RUN(NAME symbolics_02 LABELS cpython_sym c_sym)
605605
RUN(NAME symbolics_03 LABELS cpython_sym c_sym)
606606
RUN(NAME symbolics_04 LABELS cpython_sym c_sym)
607607
RUN(NAME symbolics_05 LABELS cpython_sym c_sym)
608+
RUN(NAME symbolics_06 LABELS cpython_sym c_sym)
608609

609610
RUN(NAME sizeof_01 LABELS llvm c
610611
EXTRAFILES sizeof_01b.c)

integration_tests/symbolics_06.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from sympy import Symbol, sin, cos, exp, log, Abs, pi, diff
2+
from lpython import S
3+
4+
def test_elementary_functions():
5+
6+
# test sin, cos
7+
x: S = Symbol('x')
8+
assert(sin(pi) == S(0))
9+
assert(sin(pi/S(2)) == S(1))
10+
assert(sin(S(2)*pi) == S(0))
11+
assert(cos(pi) == S(-1))
12+
assert(cos(pi/S(2)) == S(0))
13+
assert(cos(S(2)*pi) == S(1))
14+
assert(diff(sin(x), x) == cos(x))
15+
assert(diff(cos(x), x) == S(-1)*sin(x))
16+
17+
# test exp, log
18+
assert(exp(S(0)) == S(1))
19+
assert(log(S(1)) == S(0))
20+
assert(diff(exp(x), x) == exp(x))
21+
assert(diff(log(x), x) == S(1)/x)
22+
23+
# test Abs
24+
assert(Abs(S(-10)) == S(10))
25+
assert(Abs(S(10)) == S(10))
26+
assert(Abs(S(-1)*x) == Abs(x))
27+
28+
# test composite functions
29+
a: S = exp(x)
30+
b: S = sin(a)
31+
c: S = cos(b)
32+
d: S = log(c)
33+
e: S = Abs(d)
34+
print(e)
35+
assert(e == Abs(log(cos(sin(exp(x))))))
36+
37+
test_elementary_functions()

src/libasr/codegen/asr_to_c_cpp.h

Lines changed: 48 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2702,7 +2702,7 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
27022702
out += func_name; break; \
27032703
}
27042704

2705-
std::string performSymbolicOperation(const std::string& functionName, const ASR::IntrinsicFunction_t& x) {
2705+
std::string performBinarySymbolicOperation(const std::string& functionName, const ASR::IntrinsicFunction_t& x) {
27062706
headers.insert("symengine/cwrapper.h");
27072707
std::string indent(4, ' ');
27082708
LCOMPILERS_ASSERT(x.n_args == 2);
@@ -2727,6 +2727,23 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
27272727
return target;
27282728
}
27292729

2730+
std::string performUnarySymbolicOperation(const std::string& functionName, const ASR::IntrinsicFunction_t& x) {
2731+
headers.insert("symengine/cwrapper.h");
2732+
std::string indent(4, ' ');
2733+
LCOMPILERS_ASSERT(x.n_args == 1);
2734+
std::string target = symengine_queue.push();
2735+
std::string target_src = symengine_src;
2736+
this->visit_expr(*x.m_args[0]);
2737+
std::string arg1 = src;
2738+
std::string arg1_src = symengine_src;
2739+
if (ASR::is_a<ASR::Var_t>(*x.m_args[0])) {
2740+
symengine_queue.pop();
2741+
}
2742+
symengine_src = target_src + arg1_src;
2743+
symengine_src += indent + functionName + "(" + target + ", " + arg1 + ");\n";
2744+
return target;
2745+
}
2746+
27302747
void visit_IntrinsicFunction(const ASR::IntrinsicFunction_t &x) {
27312748
std::string out;
27322749
std::string indent(4, ' ');
@@ -2745,27 +2762,51 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
27452762
SET_INTRINSIC_NAME(Exp2, "exp2");
27462763
SET_INTRINSIC_NAME(Expm1, "expm1");
27472764
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicAdd)): {
2748-
src = performSymbolicOperation("basic_add", x);
2765+
src = performBinarySymbolicOperation("basic_add", x);
27492766
return;
27502767
}
27512768
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicSub)): {
2752-
src = performSymbolicOperation("basic_sub", x);
2769+
src = performBinarySymbolicOperation("basic_sub", x);
27532770
return;
27542771
}
27552772
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicMul)): {
2756-
src = performSymbolicOperation("basic_mul", x);
2773+
src = performBinarySymbolicOperation("basic_mul", x);
27572774
return;
27582775
}
27592776
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicDiv)): {
2760-
src = performSymbolicOperation("basic_div", x);
2777+
src = performBinarySymbolicOperation("basic_div", x);
27612778
return;
27622779
}
27632780
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicPow)): {
2764-
src = performSymbolicOperation("basic_pow", x);
2781+
src = performBinarySymbolicOperation("basic_pow", x);
27652782
return;
27662783
}
27672784
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicDiff)): {
2768-
src = performSymbolicOperation("basic_diff", x);
2785+
src = performBinarySymbolicOperation("basic_diff", x);
2786+
return;
2787+
}
2788+
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicSin)): {
2789+
src = performUnarySymbolicOperation("basic_sin", x);
2790+
return;
2791+
}
2792+
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicCos)): {
2793+
src = performUnarySymbolicOperation("basic_cos", x);
2794+
return;
2795+
}
2796+
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicLog)): {
2797+
src = performUnarySymbolicOperation("basic_log", x);
2798+
return;
2799+
}
2800+
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicExp)): {
2801+
src = performUnarySymbolicOperation("basic_exp", x);
2802+
return;
2803+
}
2804+
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicAbs)): {
2805+
src = performUnarySymbolicOperation("basic_abs", x);
2806+
return;
2807+
}
2808+
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicExpand)): {
2809+
src = performUnarySymbolicOperation("basic_expand", x);
27692810
return;
27702811
}
27712812
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicPi)): {
@@ -2794,22 +2835,6 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
27942835
src = target;
27952836
return;
27962837
}
2797-
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicExpand)): {
2798-
headers.insert("symengine/cwrapper.h");
2799-
LCOMPILERS_ASSERT(x.n_args == 1);
2800-
std::string target = symengine_queue.push();
2801-
std::string target_src = symengine_src;
2802-
this->visit_expr(*x.m_args[0]);
2803-
std::string arg1 = src;
2804-
std::string arg1_src = symengine_src;
2805-
if (ASR::is_a<ASR::Var_t>(*x.m_args[0])) {
2806-
symengine_queue.pop();
2807-
}
2808-
symengine_src = target_src + arg1_src;
2809-
symengine_src += indent + "basic_expand(" + target + ", " + arg1 + ");\n";
2810-
src = target;
2811-
return;
2812-
}
28132838
default : {
28142839
throw LCompilersException("IntrinsicFunction: `"
28152840
+ ASRUtils::get_intrinsic_name(x.m_intrinsic_id)

src/libasr/pass/intrinsic_function_registry.h

Lines changed: 80 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,11 @@ enum class IntrinsicFunctions : int64_t {
7474
SymbolicInteger,
7575
SymbolicDiff,
7676
SymbolicExpand,
77+
SymbolicSin,
78+
SymbolicCos,
79+
SymbolicLog,
80+
SymbolicExp,
81+
SymbolicAbs,
7782
Sum,
7883
// ...
7984
};
@@ -2169,45 +2174,52 @@ namespace SymbolicInteger {
21692174
}
21702175
} // namespace SymbolicInteger
21712176

2172-
namespace SymbolicExpand {
2173-
2174-
static inline void verify_args(const ASR::IntrinsicFunction_t& x, diag::Diagnostics& diagnostics) {
2175-
const Location& loc = x.base.base.loc;
2176-
ASRUtils::require_impl(x.n_args == 1,
2177-
"SymbolicExpand must have exactly 1 input argument",
2178-
loc, diagnostics);
2179-
2180-
ASR::ttype_t* input_type = ASRUtils::expr_type(x.m_args[0]);
2181-
ASRUtils::require_impl(ASR::is_a<ASR::SymbolicExpression_t>(*input_type),
2182-
"SymbolicExpand expects an argument of type SymbolicExpression",
2183-
x.base.base.loc, diagnostics);
2184-
}
2185-
2186-
static inline ASR::expr_t *eval_SymbolicExpand(Allocator &/*al*/,
2187-
const Location &/*loc*/, Vec<ASR::expr_t*>& /*args*/) {
2188-
// TODO
2189-
return nullptr;
2190-
}
2191-
2192-
static inline ASR::asr_t* create_SymbolicExpand(Allocator& al, const Location& loc,
2193-
Vec<ASR::expr_t*>& args,
2194-
const std::function<void (const std::string &, const Location &)> err) {
2195-
if (args.size() != 1) {
2196-
err("Intrinsic expand function accepts exactly 1 argument", loc);
2197-
}
2198-
2199-
ASR::ttype_t* argtype = ASRUtils::expr_type(args[0]);
2200-
if(!ASR::is_a<ASR::SymbolicExpression_t>(*argtype)) {
2201-
err("Argument of SymbolicExpand function must be of type SymbolicExpression",
2202-
args[0]->base.loc);
2203-
}
2204-
2205-
ASR::ttype_t *to_type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, loc));
2206-
return UnaryIntrinsicFunction::create_UnaryFunction(al, loc, args, eval_SymbolicExpand,
2207-
static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicExpand), 0, to_type);
2208-
}
2177+
#define create_symbolic_unary_macro(X) \
2178+
namespace X { \
2179+
\
2180+
static inline void verify_args(const ASR::IntrinsicFunction_t& x, \
2181+
diag::Diagnostics& diagnostics) { \
2182+
const Location& loc = x.base.base.loc; \
2183+
ASRUtils::require_impl(x.n_args == 1, \
2184+
#X " must have exactly 1 input argument", loc, diagnostics); \
2185+
\
2186+
ASR::ttype_t* input_type = ASRUtils::expr_type(x.m_args[0]); \
2187+
ASRUtils::require_impl(ASR::is_a<ASR::SymbolicExpression_t>(*input_type), \
2188+
#X " expects an argument of type SymbolicExpression", loc, diagnostics); \
2189+
} \
2190+
\
2191+
static inline ASR::expr_t* eval_##X(Allocator &/*al*/, const Location &/*loc*/, \
2192+
Vec<ASR::expr_t*> &/*args*/) { \
2193+
/*TODO*/ \
2194+
return nullptr; \
2195+
} \
2196+
\
2197+
static inline ASR::asr_t* create_##X(Allocator& al, const Location& loc, \
2198+
Vec<ASR::expr_t*>& args, \
2199+
const std::function<void (const std::string &, const Location &)> err) { \
2200+
if (args.size() != 1) { \
2201+
err("Intrinsic " #X " function accepts exactly 1 argument", loc); \
2202+
} \
2203+
\
2204+
ASR::ttype_t* argtype = ASRUtils::expr_type(args[0]); \
2205+
if (!ASR::is_a<ASR::SymbolicExpression_t>(*argtype)) { \
2206+
err("Argument of " #X " function must be of type SymbolicExpression", \
2207+
args[0]->base.loc); \
2208+
} \
2209+
\
2210+
ASR::ttype_t *to_type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, loc)); \
2211+
return UnaryIntrinsicFunction::create_UnaryFunction(al, loc, args, eval_##X, \
2212+
static_cast<int64_t>(ASRUtils::IntrinsicFunctions::X), 0, to_type); \
2213+
} \
2214+
\
2215+
} // namespace X
22092216

2210-
} // namespace SymbolicExpand
2217+
create_symbolic_unary_macro(SymbolicSin)
2218+
create_symbolic_unary_macro(SymbolicCos)
2219+
create_symbolic_unary_macro(SymbolicLog)
2220+
create_symbolic_unary_macro(SymbolicExp)
2221+
create_symbolic_unary_macro(SymbolicAbs)
2222+
create_symbolic_unary_macro(SymbolicExpand)
22112223

22122224
namespace IntrinsicFunctionRegistry {
22132225

@@ -2275,6 +2287,16 @@ namespace IntrinsicFunctionRegistry {
22752287
{nullptr, &SymbolicDiff::verify_args}},
22762288
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicExpand),
22772289
{nullptr, &SymbolicExpand::verify_args}},
2290+
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicSin),
2291+
{nullptr, &SymbolicSin::verify_args}},
2292+
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicCos),
2293+
{nullptr, &SymbolicCos::verify_args}},
2294+
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicLog),
2295+
{nullptr, &SymbolicLog::verify_args}},
2296+
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicExp),
2297+
{nullptr, &SymbolicExp::verify_args}},
2298+
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicAbs),
2299+
{nullptr, &SymbolicAbs::verify_args}},
22782300
};
22792301

22802302
static const std::map<int64_t, std::string>& intrinsic_function_id_to_name = {
@@ -2333,6 +2355,16 @@ namespace IntrinsicFunctionRegistry {
23332355
"SymbolicDiff"},
23342356
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicExpand),
23352357
"SymbolicExpand"},
2358+
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicSin),
2359+
"SymbolicSin"},
2360+
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicCos),
2361+
"SymbolicCos"},
2362+
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicLog),
2363+
"SymbolicLog"},
2364+
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicExp),
2365+
"SymbolicExp"},
2366+
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicAbs),
2367+
"SymbolicAbs"},
23362368
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::Any),
23372369
"any"},
23382370
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::Sum),
@@ -2372,6 +2404,11 @@ namespace IntrinsicFunctionRegistry {
23722404
{"SymbolicInteger", {&SymbolicInteger::create_SymbolicInteger, &SymbolicInteger::eval_SymbolicInteger}},
23732405
{"diff", {&SymbolicDiff::create_SymbolicDiff, &SymbolicDiff::eval_SymbolicDiff}},
23742406
{"expand", {&SymbolicExpand::create_SymbolicExpand, &SymbolicExpand::eval_SymbolicExpand}},
2407+
{"SymbolicSin", {&SymbolicSin::create_SymbolicSin, &SymbolicSin::eval_SymbolicSin}},
2408+
{"SymbolicCos", {&SymbolicCos::create_SymbolicCos, &SymbolicCos::eval_SymbolicCos}},
2409+
{"SymbolicLog", {&SymbolicLog::create_SymbolicLog, &SymbolicLog::eval_SymbolicLog}},
2410+
{"SymbolicExp", {&SymbolicExp::create_SymbolicExp, &SymbolicExp::eval_SymbolicExp}},
2411+
{"SymbolicAbs", {&SymbolicAbs::create_SymbolicAbs, &SymbolicAbs::eval_SymbolicAbs}},
23752412
};
23762413

23772414
static inline bool is_intrinsic_function(const std::string& name) {
@@ -2488,6 +2525,11 @@ inline std::string get_intrinsic_name(int x) {
24882525
INTRINSIC_NAME_CASE(SymbolicInteger)
24892526
INTRINSIC_NAME_CASE(SymbolicDiff)
24902527
INTRINSIC_NAME_CASE(SymbolicExpand)
2528+
INTRINSIC_NAME_CASE(SymbolicSin)
2529+
INTRINSIC_NAME_CASE(SymbolicCos)
2530+
INTRINSIC_NAME_CASE(SymbolicLog)
2531+
INTRINSIC_NAME_CASE(SymbolicExp)
2532+
INTRINSIC_NAME_CASE(SymbolicAbs)
24912533
INTRINSIC_NAME_CASE(Sum)
24922534
default : {
24932535
throw LCompilersException("pickle: intrinsic_id not implemented");

src/lpython/semantics/python_ast_to_asr.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7265,15 +7265,23 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
72657265
}
72667266

72677267
if (!s) {
7268+
std::string intrinsic_name = call_name;
72687269
std::set<std::string> not_cpython_builtin = {
72697270
"sin", "cos", "gamma", "tan", "asin", "acos", "atan", "sinh", "cosh", "tanh", "exp", "exp2", "expm1", "Symbol", "diff", "expand",
72707271
"sum" // For sum called over lists
72717272
};
7272-
if (ASRUtils::IntrinsicFunctionRegistry::is_intrinsic_function(call_name) &&
7273+
std::set<std::string> symbolic_functions = {
7274+
"sin", "cos", "log", "exp", "Abs"
7275+
};
7276+
if ((symbolic_functions.find(call_name) != symbolic_functions.end()) &&
7277+
imported_functions[call_name] == "sympy"){
7278+
intrinsic_name = "Symbolic" + std::string(1, std::toupper(call_name[0])) + call_name.substr(1);
7279+
}
7280+
if (ASRUtils::IntrinsicFunctionRegistry::is_intrinsic_function(intrinsic_name) &&
72737281
(not_cpython_builtin.find(call_name) == not_cpython_builtin.end() ||
72747282
imported_functions.find(call_name) != imported_functions.end() )) {
72757283
ASRUtils::create_intrinsic_function create_func =
7276-
ASRUtils::IntrinsicFunctionRegistry::get_create_function(call_name);
7284+
ASRUtils::IntrinsicFunctionRegistry::get_create_function(intrinsic_name);
72777285
Vec<ASR::expr_t*> args_; args_.reserve(al, x.n_args);
72787286
visit_expr_list(x.m_args, x.n_args, args_);
72797287
if (ASRUtils::is_array(ASRUtils::expr_type(args_[0])) &&

0 commit comments

Comments
 (0)