Skip to content

Initial setup for test_gruntz #2423

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions integration_tests/test_gruntz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from lpython import S
from sympy import Symbol

def mrv(e: S, x: S) -> tuple[dict[S, S], S]:
if not e.has(x):
empty_dict : dict[S, S] = {}
return empty_dict, x
else:
raise

def test_mrv():
x: S = Symbol("x")
y: S = Symbol("y")
ans: tuple[dict[S, S], S] = mrv(y, x)

test_mrv()
17 changes: 10 additions & 7 deletions src/libasr/pass/pass_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ namespace LCompilers {
static inline bool is_aggregate_or_array_type(ASR::expr_t* var) {
return (ASR::is_a<ASR::Struct_t>(*ASRUtils::expr_type(var)) ||
ASRUtils::is_array(ASRUtils::expr_type(var)) ||
ASR::is_a<ASR::SymbolicExpression_t>(*ASRUtils::expr_type(var)));
ASR::is_a<ASR::SymbolicExpression_t>(*ASRUtils::expr_type(var)) ||
ASR::is_a<ASR::Tuple_t>(*ASRUtils::expr_type(var)));
}

template <class Struct>
Expand Down Expand Up @@ -776,7 +777,7 @@ namespace LCompilers {
}

static inline void handle_fn_return_var(Allocator &al, ASR::Function_t *x,
bool (*is_array_or_struct_or_symbolic)(ASR::expr_t*)) {
bool (*is_array_or_struct_or_symbolic_or_tuple)(ASR::expr_t*)) {
if (ASRUtils::get_FunctionType(x)->m_abi == ASR::abiType::BindPython) {
return;
}
Expand All @@ -788,7 +789,7 @@ namespace LCompilers {
* in avoiding deep copies and the destination memory directly gets
* filled inside the function.
*/
if( is_array_or_struct_or_symbolic(x->m_return_var)) {
if( is_array_or_struct_or_symbolic_or_tuple(x->m_return_var)) {
for( auto& s_item: x->m_symtab->get_scope() ) {
ASR::symbol_t* curr_sym = s_item.second;
if( curr_sym->type == ASR::symbolType::Variable ) {
Expand Down Expand Up @@ -824,9 +825,11 @@ namespace LCompilers {
s_func_type->m_return_var_type = nullptr;

Vec<ASR::stmt_t*> func_body;
func_body.reserve(al, x->n_body - 1);
for (size_t i=0; i< x->n_body - 1; i++) {
func_body.push_back(al, x->m_body[i]);
func_body.reserve(al, x->n_body);
for (size_t i=0; i< x->n_body; i++) {
if (!ASR::is_a<ASR::Return_t>(*x->m_body[i])) {
func_body.push_back(al, x->m_body[i]);
}
}
x->m_body = func_body.p;
x->n_body = func_body.n;
Expand All @@ -835,7 +838,7 @@ namespace LCompilers {
for (auto &item : x->m_symtab->get_scope()) {
if (ASR::is_a<ASR::Function_t>(*item.second)) {
handle_fn_return_var(al, ASR::down_cast<ASR::Function_t>(
item.second), is_array_or_struct_or_symbolic);
item.second), is_array_or_struct_or_symbolic_or_tuple);
}
}
}
Expand Down
130 changes: 128 additions & 2 deletions src/libasr/pass/replace_symbolic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,35 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi

ASR::ttype_t* f_signature= xx.m_function_signature;
ASR::FunctionType_t *f_type = ASR::down_cast<ASR::FunctionType_t>(f_signature);
ASR::ttype_t *type1 = ASRUtils::TYPE(ASR::make_CPtr_t(al, xx.base.base.loc));
ASR::ttype_t *CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, xx.base.base.loc));
for (size_t i = 0; i < f_type->n_arg_types; ++i) {
if (f_type->m_arg_types[i]->type == ASR::ttypeType::SymbolicExpression) {
f_type->m_arg_types[i] = type1;
f_type->m_arg_types[i] = CPtr_type;
} else if (f_type->m_arg_types[i]->type == ASR::ttypeType::Tuple) {
Vec<ASR::ttype_t*> tuple_type_vec;
ASR::Tuple_t* tuple = ASR::down_cast<ASR::Tuple_t>(f_type->m_arg_types[i]);
tuple_type_vec.reserve(al, tuple->n_type);
for( size_t i = 0; i < tuple->n_type; i++ ) {
if (tuple->m_type[i]->type == ASR::ttypeType::SymbolicExpression) {
tuple_type_vec.push_back(al, CPtr_type);
} else if (tuple->m_type[i]->type == ASR::ttypeType::Dict) {
ASR::Dict_t *dict = ASR::down_cast<ASR::Dict_t>(tuple->m_type[i]);
ASR::ttype_t *key_type = dict->m_key_type;
ASR::ttype_t *value_type = dict->m_value_type;
if (key_type->type == ASR::ttypeType::SymbolicExpression) {
key_type = CPtr_type;
}
if (value_type->type == ASR::ttypeType::SymbolicExpression) {
value_type = CPtr_type;
}
ASR::ttype_t* dict_type = ASRUtils::TYPE(ASR::make_Dict_t(al, xx.base.base.loc, key_type, value_type));
tuple_type_vec.push_back(al, dict_type);
} else {
tuple_type_vec.push_back(al, tuple->m_type[i]);
}
}
ASR::ttype_t* tuple_type = ASRUtils::TYPE(ASR::make_Tuple_t(al, xx.base.base.loc, tuple_type_vec.p, tuple_type_vec.n));
f_type->m_arg_types[i] = tuple_type;
}
}

Expand Down Expand Up @@ -256,6 +281,45 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
ASR::ttype_t* list_type = ASRUtils::TYPE(ASR::make_List_t(al, xx.base.base.loc, CPtr_type));
xx.m_type = list_type;
}
} else if (xx.m_type->type == ASR::ttypeType::Tuple) {
Vec<ASR::ttype_t*> tuple_type_vec;
ASR::Tuple_t* tuple = ASR::down_cast<ASR::Tuple_t>(xx.m_type);
tuple_type_vec.reserve(al, tuple->n_type);
ASR::ttype_t *CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, xx.base.base.loc));
for( size_t i = 0; i < tuple->n_type; i++ ) {
if (tuple->m_type[i]->type == ASR::ttypeType::SymbolicExpression) {
tuple_type_vec.push_back(al, CPtr_type);
} else if (tuple->m_type[i]->type == ASR::ttypeType::Dict) {
ASR::Dict_t *dict = ASR::down_cast<ASR::Dict_t>(tuple->m_type[i]);
ASR::ttype_t *key_type = dict->m_key_type;
ASR::ttype_t *value_type = dict->m_value_type;
if (key_type->type == ASR::ttypeType::SymbolicExpression) {
key_type = CPtr_type;
}
if (value_type->type == ASR::ttypeType::SymbolicExpression) {
value_type = CPtr_type;
}
ASR::ttype_t* dict_type = ASRUtils::TYPE(ASR::make_Dict_t(al, xx.base.base.loc, key_type, value_type));
tuple_type_vec.push_back(al, dict_type);
} else {
tuple_type_vec.push_back(al, tuple->m_type[i]);
}
}
ASR::ttype_t* tuple_type = ASRUtils::TYPE(ASR::make_Tuple_t(al, xx.base.base.loc, tuple_type_vec.p, tuple_type_vec.n));
xx.m_type = tuple_type;
} else if (xx.m_type->type == ASR::ttypeType::Dict) {
ASR::ttype_t *CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, xx.base.base.loc));
ASR::Dict_t *dict = ASR::down_cast<ASR::Dict_t>(xx.m_type);
ASR::ttype_t *key_type = dict->m_key_type;
ASR::ttype_t *value_type = dict->m_value_type;
if (key_type->type == ASR::ttypeType::SymbolicExpression) {
key_type = CPtr_type;
}
if (value_type->type == ASR::ttypeType::SymbolicExpression) {
value_type = CPtr_type;
}
ASR::ttype_t* dict_type = ASRUtils::TYPE(ASR::make_Dict_t(al, xx.base.base.loc, key_type, value_type));
xx.m_type = dict_type;
}
}

Expand Down Expand Up @@ -1374,6 +1438,57 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, function_call, nullptr));
pass_result.push_back(al, stmt);
}
} else if (ASR::is_a<ASR::TupleConstant_t>(*x.m_value)) {
ASR::TupleConstant_t* tuple_constant = ASR::down_cast<ASR::TupleConstant_t>(x.m_value);
if (tuple_constant->m_type->type == ASR::ttypeType::Tuple) {
ASR::Tuple_t* tuple = ASR::down_cast<ASR::Tuple_t>(tuple_constant->m_type);
Vec<ASR::ttype_t*> tuple_type_vec;
tuple_type_vec.reserve(al, tuple->n_type);
ASR::ttype_t *CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc));
for( size_t i = 0; i < tuple->n_type; i++ ) {
if (tuple->m_type[i]->type == ASR::ttypeType::SymbolicExpression) {
tuple_type_vec.push_back(al, CPtr_type);
} else if (tuple->m_type[i]->type == ASR::ttypeType::Dict) {
ASR::Dict_t *dict = ASR::down_cast<ASR::Dict_t>(tuple->m_type[i]);
ASR::ttype_t *key_type = dict->m_key_type;
ASR::ttype_t *value_type = dict->m_value_type;
if (key_type->type == ASR::ttypeType::SymbolicExpression) {
key_type = CPtr_type;
}
if (value_type->type == ASR::ttypeType::SymbolicExpression) {
value_type = CPtr_type;
}
ASR::ttype_t* dict_type = ASRUtils::TYPE(ASR::make_Dict_t(al, x.base.base.loc, key_type, value_type));
tuple_type_vec.push_back(al, dict_type);
} else {
tuple_type_vec.push_back(al, tuple->m_type[i]);
}
}
ASR::ttype_t* tuple_type = ASRUtils::TYPE(ASR::make_Tuple_t(al, x.base.base.loc, tuple_type_vec.p, tuple_type_vec.n));
ASR::expr_t* temp_tuple_const = ASRUtils::EXPR(ASR::make_TupleConstant_t(al, x.base.base.loc, tuple_constant->m_elements,
tuple_constant->n_elements, tuple_type));
ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, temp_tuple_const, nullptr));
pass_result.push_back(al, stmt);
}
} else if (ASR::is_a<ASR::DictConstant_t>(*x.m_value)) {
ASR::DictConstant_t* dict_constant = ASR::down_cast<ASR::DictConstant_t>(x.m_value);
if (dict_constant->m_type->type == ASR::ttypeType::Dict) {
ASR::Dict_t* dict = ASR::down_cast<ASR::Dict_t>(dict_constant->m_type);
ASR::ttype_t *CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc));
ASR::ttype_t *key_type = dict->m_key_type;
ASR::ttype_t *value_type = dict->m_value_type;
if (key_type->type == ASR::ttypeType::SymbolicExpression) {
key_type = CPtr_type;
}
if (value_type->type == ASR::ttypeType::SymbolicExpression) {
value_type = CPtr_type;
}
ASR::ttype_t* dict_type = ASRUtils::TYPE(ASR::make_Dict_t(al, x.base.base.loc, key_type, value_type));
ASR::expr_t* temp_dict_const = ASRUtils::EXPR(ASR::make_DictConstant_t(al, x.base.base.loc, dict_constant->m_keys,
dict_constant->n_keys, dict_constant->m_values, dict_constant->n_values, dict_type));
ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, temp_dict_const, nullptr));
pass_result.push_back(al, stmt);
}
}
}

Expand All @@ -1388,6 +1503,17 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
ASR::expr_t* function_call = process_attributes(al, xx.base.base.loc, xx.m_test, module_scope);
xx.m_test = function_call;
}
} else if (ASR::is_a<ASR::LogicalNot_t>(*xx.m_test)) {
ASR::LogicalNot_t* logical_not = ASR::down_cast<ASR::LogicalNot_t>(xx.m_test);
if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*logical_not->m_arg)) {
ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(logical_not->m_arg);
if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) {
ASR::expr_t* function_call = process_attributes(al, xx.base.base.loc, logical_not->m_arg, module_scope);
ASR::expr_t* new_logical_not = ASRUtils::EXPR(ASR::make_LogicalNot_t(al, xx.base.base.loc, function_call,
logical_not->m_type, logical_not->m_value));
xx.m_test = new_logical_not;
}
}
}
}

Expand Down