Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/libasr/ASR.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,7 @@ cast_kind
| RealToUnsignedInteger
| CPtrToUnsignedInteger
| UnsignedIntegerToCPtr
| IntegerToSymbolicExpression

dimension = (expr? start, expr? length)

Expand Down
10 changes: 10 additions & 0 deletions src/libasr/asr_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <libasr/utils.h>
#include <libasr/modfile.h>
#include <libasr/pass/pass_utils.h>
#include <libasr/pass/intrinsic_function_registry.h>

namespace LCompilers {

Expand Down Expand Up @@ -1196,6 +1197,15 @@ ASR::asr_t* make_Cast_t_value(Allocator &al, const Location &a_loc,
double real = value_complex->m_re;
value = ASR::down_cast<ASR::expr_t>(
ASR::make_RealConstant_t(al, a_loc, real, a_type));
} else if (a_kind == ASR::cast_kindType::IntegerToSymbolicExpression) {
Vec<ASR::expr_t*> args;
args.reserve(al, 1);
args.push_back(al, a_arg);
LCompilers::ASRUtils::create_intrinsic_function create_function =
LCompilers::ASRUtils::IntrinsicFunctionRegistry::get_create_function("SymbolicInteger");
value = ASR::down_cast<ASR::expr_t>(create_function(al, a_loc, args,
[](const std::string&, const Location&) {
}));
}
}

Expand Down
9 changes: 8 additions & 1 deletion src/libasr/codegen/asr_to_c_cpp.h
Original file line number Diff line number Diff line change
Expand Up @@ -1646,6 +1646,9 @@ R"(#include <stdio.h>
last_expr_precedence = 2;
break;
}
case (ASR::cast_kindType::IntegerToSymbolicExpression): {
break;
}
default : throw CodeGenError("Cast kind " + std::to_string(x.m_kind) + " not implemented",
x.base.base.loc);
}
Expand Down Expand Up @@ -2395,7 +2398,11 @@ R"(#include <stdio.h>
SET_INTRINSIC_NAME(Expm1, "expm1");
SET_INTRINSIC_NAME(SymbolicSymbol, "Symbol");
SET_INTRINSIC_NAME(SymbolicPi, "pi");
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicAdd)): {
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicAdd)):
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicSub)):
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicMul)):
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicDiv)):
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicPow)): {
LCOMPILERS_ASSERT(x.n_args == 2);
this->visit_expr(*x.m_args[0]);
std::string arg1 = src;
Expand Down
37 changes: 33 additions & 4 deletions src/libasr/codegen/c_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,14 @@ class CCPPDSUtils {
return result;
}

std::string generate_binary_operator_code(std::string value, std::string target, std::string operatorName) {
size_t delimiterPos = value.find(",");
std::string leftPart = value.substr(0, delimiterPos);
std::string rightPart = value.substr(delimiterPos + 1);
std::string result = operatorName + "(" + target + ", " + leftPart + ", " + rightPart + ");";
return result;
}

std::string get_deepcopy_symbolic(ASR::expr_t *value_expr, std::string value, std::string target) {
std::string result;
if (ASR::is_a<ASR::Var_t>(*value_expr)) {
Expand All @@ -645,22 +653,43 @@ class CCPPDSUtils {
break;
}
case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicAdd: {
size_t delimiterPos = value.find(",");
std::string leftPart = value.substr(0, delimiterPos);
std::string rightPart = value.substr(delimiterPos + 1);
result = "basic_add(" + target + ", " + leftPart + ", " + rightPart + ");";
result = generate_binary_operator_code(value, target, "basic_add");
break;
}
case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSub: {
result = generate_binary_operator_code(value, target, "basic_sub");
break;
}
case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicMul: {
result = generate_binary_operator_code(value, target, "basic_mul");
break;
}
case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiv: {
result = generate_binary_operator_code(value, target, "basic_div");
break;
}
case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPow: {
result = generate_binary_operator_code(value, target, "basic_pow");
break;
}
case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPi: {
result = "basic_const_pi(" + target + ");";
break;
}
case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicInteger: {
result = "integer_set_si(" + target + ", " + value + ");";
break;
}
default: {
throw LCompilersException("IntrinsicFunction: `"
+ LCompilers::ASRUtils::get_intrinsic_name(intrinsic_id)
+ "` is not implemented");
}
}
} else if (ASR::is_a<ASR::Cast_t>(*value_expr)) {
ASR::Cast_t* cast_expr = ASR::down_cast<ASR::Cast_t>(value_expr);
std::string cast_value_expr = get_deepcopy_symbolic(cast_expr->m_value, value, target);
return cast_value_expr;
}
return result;
}
Expand Down
158 changes: 123 additions & 35 deletions src/libasr/pass/intrinsic_function_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,12 @@ enum class IntrinsicFunctions : int64_t {
ListPop,
SymbolicSymbol,
SymbolicAdd,
SymbolicSub,
SymbolicMul,
SymbolicDiv,
SymbolicPow,
SymbolicPi,
SymbolicInteger,
Sum,
// ...
};
Expand Down Expand Up @@ -2037,72 +2042,125 @@ namespace SymbolicSymbol {

} // namespace SymbolicSymbol

namespace SymbolicAdd {
#define create_symbolic_binop_macro(X) \
namespace X{ \
\
static inline void verify_args(const ASR::IntrinsicFunction_t& x, \
diag::Diagnostics& diagnostics) { \
ASRUtils::require_impl(x.n_args == 2, "Intrinsic function `"#X"` accepts \
exactly 2 arguments", x.base.base.loc, diagnostics); \
\
ASR::ttype_t* left_type = ASRUtils::expr_type(x.m_args[0]); \
ASR::ttype_t* right_type = ASRUtils::expr_type(x.m_args[1]); \
\
ASRUtils::require_impl(ASR::is_a<ASR::SymbolicExpression_t>(*left_type) && \
ASR::is_a<ASR::SymbolicExpression_t>(*right_type), \
"Both arguments of `"#X"` must be of type SymbolicExpression", \
x.base.base.loc, diagnostics); \
} \
\
static inline ASR::expr_t* eval_##X(Allocator &/*al*/, const Location &/*loc*/, \
Vec<ASR::expr_t*> &/*args*/) { \
/*TODO*/ \
return nullptr; \
} \
\
static inline ASR::asr_t* create_##X(Allocator& al, const Location& loc, \
Vec<ASR::expr_t*>& args, \
const std::function<void (const std::string &, const Location &)> err) { \
if (args.size() != 2) { \
err("Intrinsic function `"#X"` accepts exactly 2 arguments", loc); \
} \
\
for (size_t i = 0; i < args.size(); i++) { \
ASR::ttype_t* argtype = ASRUtils::expr_type(args[i]); \
if(!ASR::is_a<ASR::SymbolicExpression_t>(*argtype)) { \
err("Arguments of `"#X"` function must be of type SymbolicExpression", \
args[i]->base.loc); \
} \
} \
\
Vec<ASR::expr_t*> arg_values; \
arg_values.reserve(al, args.size()); \
for( size_t i = 0; i < args.size(); i++ ) { \
arg_values.push_back(al, ASRUtils::expr_value(args[i])); \
} \
ASR::expr_t* compile_time_value = eval_##X(al, loc, arg_values); \
ASR::ttype_t *to_type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, loc)); \
return ASR::make_IntrinsicFunction_t(al, loc, \
static_cast<int64_t>(ASRUtils::IntrinsicFunctions::X), \
args.p, args.size(), 0, to_type, compile_time_value); \
} \
} // namespace X

static inline void verify_args(const ASR::IntrinsicFunction_t& x, diag::Diagnostics& diagnostics) {
ASRUtils::require_impl(x.n_args == 2, "SymbolicAdd must have exactly two arguments",
x.base.base.loc, diagnostics);
create_symbolic_binop_macro(SymbolicAdd)
create_symbolic_binop_macro(SymbolicSub)
create_symbolic_binop_macro(SymbolicMul)
create_symbolic_binop_macro(SymbolicDiv)
create_symbolic_binop_macro(SymbolicPow)

ASR::ttype_t* left_type = ASRUtils::expr_type(x.m_args[0]);
ASR::ttype_t* right_type = ASRUtils::expr_type(x.m_args[1]);
namespace SymbolicPi {

ASRUtils::require_impl(ASR::is_a<ASR::SymbolicExpression_t>(*left_type) &&
ASR::is_a<ASR::SymbolicExpression_t>(*right_type),
"Both arguments of SymbolicAdd must be of type SymbolicExpression",
static inline void verify_args(const ASR::IntrinsicFunction_t& x, diag::Diagnostics& diagnostics) {
ASRUtils::require_impl(x.n_args == 0, "SymbolicPi does not take arguments",
x.base.base.loc, diagnostics);
}

static inline ASR::expr_t *eval_SymbolicAdd(Allocator &/*al*/,
static inline ASR::expr_t *eval_SymbolicPi(Allocator &/*al*/,
const Location &/*loc*/, Vec<ASR::expr_t*>& /*args*/) {
// TODO
return nullptr;
}

static inline ASR::asr_t* create_SymbolicAdd(Allocator& al, const Location& loc,
static inline ASR::asr_t* create_SymbolicPi(Allocator& al, const Location& loc,
Vec<ASR::expr_t*>& args,
const std::function<void (const std::string &, const Location &)> err) {
if (args.size() != 2) {
err("Intrinsic Symbol Add operator accepts exactly 2 arguments", loc);
}

Vec<ASR::expr_t*> arg_values;
arg_values.reserve(al, args.size());
for( size_t i = 0; i < args.size(); i++ ) {
arg_values.push_back(al, ASRUtils::expr_value(args[i]));
}
ASR::expr_t* compile_time_value = eval_SymbolicAdd(al, loc, arg_values);
const std::function<void (const std::string &, const Location &)> /*err*/) {
ASR::expr_t* compile_time_value = eval_SymbolicPi(al, loc, args);
ASR::ttype_t *to_type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, loc));
return ASR::make_IntrinsicFunction_t(al, loc,
static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicAdd),
args.p, args.size(), 0, to_type, compile_time_value);
static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicPi),
nullptr, 0, 0, to_type, compile_time_value);
}

} // namespace SymbolicAdd
} // namespace SymbolicPi

namespace SymbolicPi {
namespace SymbolicInteger {

static inline void verify_args(const ASR::IntrinsicFunction_t& x, diag::Diagnostics& diagnostics) {
ASRUtils::require_impl(x.n_args == 0, "SymbolicPi does not take arguments",
ASRUtils::require_impl(x.n_args == 1,
"SymbolicInteger intrinsic must have exactly 1 input argument",
x.base.base.loc, diagnostics);

ASR::ttype_t* input_type = ASRUtils::expr_type(x.m_args[0]);
ASRUtils::require_impl(ASR::is_a<ASR::Integer_t>(*input_type),
"SymbolicInteger intrinsic expects an integer input argument",
x.base.base.loc, diagnostics);
}

static inline ASR::expr_t *eval_SymbolicPi(Allocator &/*al*/,
static inline ASR::expr_t* eval_SymbolicInteger(Allocator &/*al*/,
const Location &/*loc*/, Vec<ASR::expr_t*>& /*args*/) {
// TODO
return nullptr;
}

static inline ASR::asr_t* create_SymbolicPi(Allocator& al, const Location& loc,
static inline ASR::asr_t* create_SymbolicInteger(Allocator& al, const Location& loc,
Vec<ASR::expr_t*>& args,
const std::function<void (const std::string &, const Location &)> /*err*/) {
ASR::expr_t* compile_time_value = eval_SymbolicPi(al, loc, args);
// if (args.size() != 1) {
// err("Intrinsic SymbolicInteger function accepts exactly 1 argument", loc);
// }

// ASR::ttype_t* type = ASRUtils::expr_type(args[0]);
// if (!ASRUtils::is_integer(*type)) {
// err("Argument of the SymbolicInteger function must be an Integer",
// args[0]->base.loc);
// }

ASR::ttype_t *to_type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, loc));
return ASR::make_IntrinsicFunction_t(al, loc,
static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicPi),
nullptr, 0, 0, to_type, compile_time_value);
return UnaryIntrinsicFunction::create_UnaryFunction(al, loc, args, eval_SymbolicInteger,
static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicInteger), 0, to_type);
}

} // namespace SymbolicPi
} // namespace SymbolicInteger

namespace IntrinsicFunctionRegistry {

Expand Down Expand Up @@ -2154,8 +2212,18 @@ namespace IntrinsicFunctionRegistry {
{nullptr, &SymbolicSymbol::verify_args}},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicAdd),
{nullptr, &SymbolicAdd::verify_args}},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicSub),
{nullptr, &SymbolicSub::verify_args}},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicMul),
{nullptr, &SymbolicMul::verify_args}},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicDiv),
{nullptr, &SymbolicDiv::verify_args}},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicPow),
{nullptr, &SymbolicPow::verify_args}},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicPi),
{nullptr, &SymbolicPi::verify_args}},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicInteger),
{nullptr, &SymbolicInteger::verify_args}},
};

static const std::map<int64_t, std::string>& intrinsic_function_id_to_name = {
Expand Down Expand Up @@ -2198,8 +2266,18 @@ namespace IntrinsicFunctionRegistry {
"Symbol"},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicAdd),
"SymbolicAdd"},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicSub),
"SymbolicSub"},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicMul),
"SymbolicMul"},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicDiv),
"SymbolicDiv"},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicPow),
"SymbolicPow"},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicPi),
"pi"},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicInteger),
"SymbolicInteger"},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::Any),
"any"},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::Sum),
Expand Down Expand Up @@ -2231,7 +2309,12 @@ namespace IntrinsicFunctionRegistry {
{"list.pop", {&ListPop::create_ListPop, &ListPop::eval_list_pop}},
{"Symbol", {&SymbolicSymbol::create_SymbolicSymbol, &SymbolicSymbol::eval_SymbolicSymbol}},
{"SymbolicAdd", {&SymbolicAdd::create_SymbolicAdd, &SymbolicAdd::eval_SymbolicAdd}},
{"SymbolicSub", {&SymbolicSub::create_SymbolicSub, &SymbolicSub::eval_SymbolicSub}},
{"SymbolicMul", {&SymbolicMul::create_SymbolicMul, &SymbolicMul::eval_SymbolicMul}},
{"SymbolicDiv", {&SymbolicDiv::create_SymbolicDiv, &SymbolicDiv::eval_SymbolicDiv}},
{"SymbolicPow", {&SymbolicPow::create_SymbolicPow, &SymbolicPow::eval_SymbolicPow}},
{"pi", {&SymbolicPi::create_SymbolicPi, &SymbolicPi::eval_SymbolicPi}},
{"SymbolicInteger", {&SymbolicInteger::create_SymbolicInteger, &SymbolicInteger::eval_SymbolicInteger}},
};

static inline bool is_intrinsic_function(const std::string& name) {
Expand Down Expand Up @@ -2340,7 +2423,12 @@ inline std::string get_intrinsic_name(int x) {
INTRINSIC_NAME_CASE(ListPop)
INTRINSIC_NAME_CASE(SymbolicSymbol)
INTRINSIC_NAME_CASE(SymbolicAdd)
INTRINSIC_NAME_CASE(SymbolicSub)
INTRINSIC_NAME_CASE(SymbolicMul)
INTRINSIC_NAME_CASE(SymbolicDiv)
INTRINSIC_NAME_CASE(SymbolicPow)
INTRINSIC_NAME_CASE(SymbolicPi)
INTRINSIC_NAME_CASE(SymbolicInteger)
INTRINSIC_NAME_CASE(Sum)
default : {
throw LCompilersException("pickle: intrinsic_id not implemented");
Expand Down
Loading