diff --git a/integration_tests/test_gruntz.py b/integration_tests/test_gruntz.py new file mode 100644 index 0000000000..8c9565e3e4 --- /dev/null +++ b/integration_tests/test_gruntz.py @@ -0,0 +1,16 @@ +from lpython import S +from sympy import Symbol + +def mrv(e: S, x: S) -> tuple[dict[S, S], S]: + if not e.has(x): + empty_dict : dict[S, S] = {} + return empty_dict, x + else: + raise + +def test_mrv(): + x: S = Symbol("x") + y: S = Symbol("y") + ans: tuple[dict[S, S], S] = mrv(y, x) + +test_mrv() diff --git a/src/libasr/pass/pass_utils.h b/src/libasr/pass/pass_utils.h index 27b5ee5bb8..29c1b5eab1 100644 --- a/src/libasr/pass/pass_utils.h +++ b/src/libasr/pass/pass_utils.h @@ -117,7 +117,8 @@ namespace LCompilers { static inline bool is_aggregate_or_array_type(ASR::expr_t* var) { return (ASR::is_a(*ASRUtils::expr_type(var)) || ASRUtils::is_array(ASRUtils::expr_type(var)) || - ASR::is_a(*ASRUtils::expr_type(var))); + ASR::is_a(*ASRUtils::expr_type(var)) || + ASR::is_a(*ASRUtils::expr_type(var))); } template @@ -776,7 +777,7 @@ namespace LCompilers { } static inline void handle_fn_return_var(Allocator &al, ASR::Function_t *x, - bool (*is_array_or_struct_or_symbolic)(ASR::expr_t*)) { + bool (*is_array_or_struct_or_symbolic_or_tuple)(ASR::expr_t*)) { if (ASRUtils::get_FunctionType(x)->m_abi == ASR::abiType::BindPython) { return; } @@ -788,7 +789,7 @@ namespace LCompilers { * in avoiding deep copies and the destination memory directly gets * filled inside the function. */ - if( is_array_or_struct_or_symbolic(x->m_return_var)) { + if( is_array_or_struct_or_symbolic_or_tuple(x->m_return_var)) { for( auto& s_item: x->m_symtab->get_scope() ) { ASR::symbol_t* curr_sym = s_item.second; if( curr_sym->type == ASR::symbolType::Variable ) { @@ -824,9 +825,11 @@ namespace LCompilers { s_func_type->m_return_var_type = nullptr; Vec func_body; - func_body.reserve(al, x->n_body - 1); - for (size_t i=0; i< x->n_body - 1; i++) { - func_body.push_back(al, x->m_body[i]); + func_body.reserve(al, x->n_body); + for (size_t i=0; i< x->n_body; i++) { + if (!ASR::is_a(*x->m_body[i])) { + func_body.push_back(al, x->m_body[i]); + } } x->m_body = func_body.p; x->n_body = func_body.n; @@ -835,7 +838,7 @@ namespace LCompilers { for (auto &item : x->m_symtab->get_scope()) { if (ASR::is_a(*item.second)) { handle_fn_return_var(al, ASR::down_cast( - item.second), is_array_or_struct_or_symbolic); + item.second), is_array_or_struct_or_symbolic_or_tuple); } } } diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 7abf80f8fd..273c73b262 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -59,10 +59,35 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(f_signature); - ASR::ttype_t *type1 = ASRUtils::TYPE(ASR::make_CPtr_t(al, xx.base.base.loc)); + ASR::ttype_t *CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, xx.base.base.loc)); for (size_t i = 0; i < f_type->n_arg_types; ++i) { if (f_type->m_arg_types[i]->type == ASR::ttypeType::SymbolicExpression) { - f_type->m_arg_types[i] = type1; + f_type->m_arg_types[i] = CPtr_type; + } else if (f_type->m_arg_types[i]->type == ASR::ttypeType::Tuple) { + Vec tuple_type_vec; + ASR::Tuple_t* tuple = ASR::down_cast(f_type->m_arg_types[i]); + tuple_type_vec.reserve(al, tuple->n_type); + for( size_t i = 0; i < tuple->n_type; i++ ) { + if (tuple->m_type[i]->type == ASR::ttypeType::SymbolicExpression) { + tuple_type_vec.push_back(al, CPtr_type); + } else if (tuple->m_type[i]->type == ASR::ttypeType::Dict) { + ASR::Dict_t *dict = ASR::down_cast(tuple->m_type[i]); + ASR::ttype_t *key_type = dict->m_key_type; + ASR::ttype_t *value_type = dict->m_value_type; + if (key_type->type == ASR::ttypeType::SymbolicExpression) { + key_type = CPtr_type; + } + if (value_type->type == ASR::ttypeType::SymbolicExpression) { + value_type = CPtr_type; + } + ASR::ttype_t* dict_type = ASRUtils::TYPE(ASR::make_Dict_t(al, xx.base.base.loc, key_type, value_type)); + tuple_type_vec.push_back(al, dict_type); + } else { + tuple_type_vec.push_back(al, tuple->m_type[i]); + } + } + ASR::ttype_t* tuple_type = ASRUtils::TYPE(ASR::make_Tuple_t(al, xx.base.base.loc, tuple_type_vec.p, tuple_type_vec.n)); + f_type->m_arg_types[i] = tuple_type; } } @@ -256,6 +281,45 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitortype == ASR::ttypeType::Tuple) { + Vec tuple_type_vec; + ASR::Tuple_t* tuple = ASR::down_cast(xx.m_type); + tuple_type_vec.reserve(al, tuple->n_type); + ASR::ttype_t *CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, xx.base.base.loc)); + for( size_t i = 0; i < tuple->n_type; i++ ) { + if (tuple->m_type[i]->type == ASR::ttypeType::SymbolicExpression) { + tuple_type_vec.push_back(al, CPtr_type); + } else if (tuple->m_type[i]->type == ASR::ttypeType::Dict) { + ASR::Dict_t *dict = ASR::down_cast(tuple->m_type[i]); + ASR::ttype_t *key_type = dict->m_key_type; + ASR::ttype_t *value_type = dict->m_value_type; + if (key_type->type == ASR::ttypeType::SymbolicExpression) { + key_type = CPtr_type; + } + if (value_type->type == ASR::ttypeType::SymbolicExpression) { + value_type = CPtr_type; + } + ASR::ttype_t* dict_type = ASRUtils::TYPE(ASR::make_Dict_t(al, xx.base.base.loc, key_type, value_type)); + tuple_type_vec.push_back(al, dict_type); + } else { + tuple_type_vec.push_back(al, tuple->m_type[i]); + } + } + ASR::ttype_t* tuple_type = ASRUtils::TYPE(ASR::make_Tuple_t(al, xx.base.base.loc, tuple_type_vec.p, tuple_type_vec.n)); + xx.m_type = tuple_type; + } else if (xx.m_type->type == ASR::ttypeType::Dict) { + ASR::ttype_t *CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, xx.base.base.loc)); + ASR::Dict_t *dict = ASR::down_cast(xx.m_type); + ASR::ttype_t *key_type = dict->m_key_type; + ASR::ttype_t *value_type = dict->m_value_type; + if (key_type->type == ASR::ttypeType::SymbolicExpression) { + key_type = CPtr_type; + } + if (value_type->type == ASR::ttypeType::SymbolicExpression) { + value_type = CPtr_type; + } + ASR::ttype_t* dict_type = ASRUtils::TYPE(ASR::make_Dict_t(al, xx.base.base.loc, key_type, value_type)); + xx.m_type = dict_type; } } @@ -1374,6 +1438,57 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*x.m_value)) { + ASR::TupleConstant_t* tuple_constant = ASR::down_cast(x.m_value); + if (tuple_constant->m_type->type == ASR::ttypeType::Tuple) { + ASR::Tuple_t* tuple = ASR::down_cast(tuple_constant->m_type); + Vec tuple_type_vec; + tuple_type_vec.reserve(al, tuple->n_type); + ASR::ttype_t *CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)); + for( size_t i = 0; i < tuple->n_type; i++ ) { + if (tuple->m_type[i]->type == ASR::ttypeType::SymbolicExpression) { + tuple_type_vec.push_back(al, CPtr_type); + } else if (tuple->m_type[i]->type == ASR::ttypeType::Dict) { + ASR::Dict_t *dict = ASR::down_cast(tuple->m_type[i]); + ASR::ttype_t *key_type = dict->m_key_type; + ASR::ttype_t *value_type = dict->m_value_type; + if (key_type->type == ASR::ttypeType::SymbolicExpression) { + key_type = CPtr_type; + } + if (value_type->type == ASR::ttypeType::SymbolicExpression) { + value_type = CPtr_type; + } + ASR::ttype_t* dict_type = ASRUtils::TYPE(ASR::make_Dict_t(al, x.base.base.loc, key_type, value_type)); + tuple_type_vec.push_back(al, dict_type); + } else { + tuple_type_vec.push_back(al, tuple->m_type[i]); + } + } + ASR::ttype_t* tuple_type = ASRUtils::TYPE(ASR::make_Tuple_t(al, x.base.base.loc, tuple_type_vec.p, tuple_type_vec.n)); + ASR::expr_t* temp_tuple_const = ASRUtils::EXPR(ASR::make_TupleConstant_t(al, x.base.base.loc, tuple_constant->m_elements, + tuple_constant->n_elements, tuple_type)); + ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, temp_tuple_const, nullptr)); + pass_result.push_back(al, stmt); + } + } else if (ASR::is_a(*x.m_value)) { + ASR::DictConstant_t* dict_constant = ASR::down_cast(x.m_value); + if (dict_constant->m_type->type == ASR::ttypeType::Dict) { + ASR::Dict_t* dict = ASR::down_cast(dict_constant->m_type); + ASR::ttype_t *CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)); + ASR::ttype_t *key_type = dict->m_key_type; + ASR::ttype_t *value_type = dict->m_value_type; + if (key_type->type == ASR::ttypeType::SymbolicExpression) { + key_type = CPtr_type; + } + if (value_type->type == ASR::ttypeType::SymbolicExpression) { + value_type = CPtr_type; + } + ASR::ttype_t* dict_type = ASRUtils::TYPE(ASR::make_Dict_t(al, x.base.base.loc, key_type, value_type)); + ASR::expr_t* temp_dict_const = ASRUtils::EXPR(ASR::make_DictConstant_t(al, x.base.base.loc, dict_constant->m_keys, + dict_constant->n_keys, dict_constant->m_values, dict_constant->n_values, dict_type)); + ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, temp_dict_const, nullptr)); + pass_result.push_back(al, stmt); + } } } @@ -1388,6 +1503,17 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*xx.m_test)) { + ASR::LogicalNot_t* logical_not = ASR::down_cast(xx.m_test); + if (ASR::is_a(*logical_not->m_arg)) { + ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast(logical_not->m_arg); + if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) { + ASR::expr_t* function_call = process_attributes(al, xx.base.base.loc, logical_not->m_arg, module_scope); + ASR::expr_t* new_logical_not = ASRUtils::EXPR(ASR::make_LogicalNot_t(al, xx.base.base.loc, function_call, + logical_not->m_type, logical_not->m_value)); + xx.m_test = new_logical_not; + } + } } }