Skip to content

Commit 37fa041

Browse files
authored
SymbolDuplicator for pass_array_by_data (#1660)
1 parent 3317fec commit 37fa041

6 files changed

+310
-101
lines changed

src/libasr/asdl_cpp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -999,7 +999,7 @@ def visitField(self, field):
999999
elif field.type == "alloc_arg":
10001000
self.emit(" ASR::alloc_arg_t alloc_arg_copy;", level)
10011001
self.emit(" alloc_arg_copy.loc = x->m_%s[i].loc;"%(field.name), level)
1002-
self.emit(" alloc_arg_copy.m_a = x->m_%s[i].m_a;"%(field.name), level)
1002+
self.emit(" alloc_arg_copy.m_a = self().duplicate_expr(x->m_%s[i].m_a);"%(field.name), level)
10031003
self.emit(" alloc_arg_copy.n_dims = x->m_%s[i].n_dims;"%(field.name), level)
10041004
self.emit(" Vec<ASR::dimension_t> dims_copy;", level)
10051005
self.emit(" dims_copy.reserve(al, alloc_arg_copy.n_dims);", level)

src/libasr/asr_utils.h

Lines changed: 226 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1539,7 +1539,7 @@ inline int extract_dimensions_from_ttype(ASR::ttype_t *x,
15391539
break;
15401540
}
15411541
default:
1542-
throw LCompilersException("Not implemented.");
1542+
throw LCompilersException("Not implemented " + std::to_string(x->type) + ".");
15431543
}
15441544
return n_dims;
15451545
}
@@ -2573,6 +2573,34 @@ class ReplaceArgVisitor: public ASR::BaseExprReplacer<ReplaceArgVisitor> {
25732573

25742574
};
25752575

2576+
inline ASR::asr_t* make_Function_t_util(Allocator& al, const Location& loc,
2577+
SymbolTable* m_symtab, char* m_name, char** m_dependencies, size_t n_dependencies,
2578+
ASR::expr_t** a_args, size_t n_args, ASR::stmt_t** m_body, size_t n_body,
2579+
ASR::expr_t* m_return_var, ASR::abiType m_abi, ASR::accessType m_access,
2580+
ASR::deftypeType m_deftype, char* m_bindc_name, bool m_elemental, bool m_pure,
2581+
bool m_module, bool m_inline, bool m_static, ASR::ttype_t** m_type_params,
2582+
size_t n_type_params, ASR::symbol_t** m_restrictions, size_t n_restrictions,
2583+
bool m_is_restriction, bool m_deterministic, bool m_side_effect_free) {
2584+
Vec<ASR::ttype_t*> arg_types;
2585+
arg_types.reserve(al, n_args);
2586+
for( size_t i = 0; i < n_args; i++ ) {
2587+
arg_types.push_back(al, ASRUtils::expr_type(a_args[i]));
2588+
}
2589+
ASR::ttype_t* return_var_type = nullptr;
2590+
if( m_return_var ) {
2591+
return_var_type = ASRUtils::expr_type(m_return_var);
2592+
}
2593+
ASR::ttype_t* func_type = ASRUtils::TYPE(ASR::make_FunctionType_t(
2594+
al, loc, arg_types.p, arg_types.size(), return_var_type, m_abi,
2595+
m_deftype, m_bindc_name, m_elemental, m_pure, m_module, m_inline,
2596+
m_static, m_type_params, n_type_params, m_restrictions, n_restrictions,
2597+
m_is_restriction));
2598+
return ASR::make_Function_t(
2599+
al, loc, m_symtab, m_name, func_type, m_dependencies, n_dependencies,
2600+
a_args, n_args, m_body, n_body, m_return_var, m_access, m_deterministic,
2601+
m_side_effect_free);
2602+
}
2603+
25762604
class ExprStmtDuplicator: public ASR::BaseExprStmtDuplicator<ExprStmtDuplicator>
25772605
{
25782606
public:
@@ -2581,6 +2609,199 @@ class ExprStmtDuplicator: public ASR::BaseExprStmtDuplicator<ExprStmtDuplicator>
25812609

25822610
};
25832611

2612+
class SymbolDuplicator {
2613+
2614+
private:
2615+
2616+
Allocator& al;
2617+
2618+
public:
2619+
2620+
SymbolDuplicator(Allocator& al_):
2621+
al(al_) {
2622+
2623+
}
2624+
2625+
void duplicate_SymbolTable(SymbolTable* symbol_table,
2626+
SymbolTable* destination_symtab) {
2627+
for( auto& item: symbol_table->get_scope() ) {
2628+
duplicate_symbol(item.second, destination_symtab);
2629+
}
2630+
}
2631+
2632+
void duplicate_symbol(ASR::symbol_t* symbol,
2633+
SymbolTable* destination_symtab) {
2634+
ASR::symbol_t* new_symbol = nullptr;
2635+
std::string new_symbol_name = "";
2636+
switch( symbol->type ) {
2637+
case ASR::symbolType::Variable: {
2638+
ASR::Variable_t* variable = ASR::down_cast<ASR::Variable_t>(symbol);
2639+
new_symbol = duplicate_Variable(variable, destination_symtab);
2640+
new_symbol_name = variable->m_name;
2641+
break;
2642+
}
2643+
case ASR::symbolType::ExternalSymbol: {
2644+
ASR::ExternalSymbol_t* external_symbol = ASR::down_cast<ASR::ExternalSymbol_t>(symbol);
2645+
new_symbol = duplicate_ExternalSymbol(external_symbol, destination_symtab);
2646+
new_symbol_name = external_symbol->m_name;
2647+
break;
2648+
}
2649+
case ASR::symbolType::AssociateBlock: {
2650+
ASR::AssociateBlock_t* associate_block = ASR::down_cast<ASR::AssociateBlock_t>(symbol);
2651+
new_symbol = duplicate_AssociateBlock(associate_block, destination_symtab);
2652+
new_symbol_name = associate_block->m_name;
2653+
break;
2654+
}
2655+
case ASR::symbolType::Function: {
2656+
ASR::Function_t* function = ASR::down_cast<ASR::Function_t>(symbol);
2657+
new_symbol = duplicate_Function(function, destination_symtab);
2658+
new_symbol_name = function->m_name;
2659+
break;
2660+
}
2661+
default: {
2662+
throw LCompilersException("Duplicating ASR::symbolType::" +
2663+
std::to_string(symbol->type) + " is not supported yet.");
2664+
}
2665+
}
2666+
if( new_symbol ) {
2667+
destination_symtab->add_symbol(new_symbol_name, new_symbol);
2668+
}
2669+
}
2670+
2671+
ASR::symbol_t* duplicate_Variable(ASR::Variable_t* variable,
2672+
SymbolTable* destination_symtab) {
2673+
ExprStmtDuplicator node_duplicator(al);
2674+
node_duplicator.success = true;
2675+
ASR::expr_t* m_symbolic_value = node_duplicator.duplicate_expr(variable->m_symbolic_value);
2676+
if( !node_duplicator.success ) {
2677+
return nullptr;
2678+
}
2679+
node_duplicator.success = true;
2680+
ASR::expr_t* m_value = node_duplicator.duplicate_expr(variable->m_value);
2681+
if( !node_duplicator.success ) {
2682+
return nullptr;
2683+
}
2684+
node_duplicator.success = true;
2685+
ASR::ttype_t* m_type = node_duplicator.duplicate_ttype(variable->m_type);
2686+
if( !node_duplicator.success ) {
2687+
return nullptr;
2688+
}
2689+
return ASR::down_cast<ASR::symbol_t>(
2690+
ASR::make_Variable_t(al, variable->base.base.loc, destination_symtab,
2691+
variable->m_name, variable->m_dependencies, variable->n_dependencies,
2692+
variable->m_intent, m_symbolic_value, m_value, variable->m_storage,
2693+
m_type, variable->m_abi, variable->m_access, variable->m_presence,
2694+
variable->m_value_attr));
2695+
}
2696+
2697+
ASR::symbol_t* duplicate_ExternalSymbol(ASR::ExternalSymbol_t* external_symbol,
2698+
SymbolTable* destination_symtab) {
2699+
return ASR::down_cast<ASR::symbol_t>(ASR::make_ExternalSymbol_t(
2700+
al, external_symbol->base.base.loc, destination_symtab,
2701+
external_symbol->m_name, external_symbol->m_external,
2702+
external_symbol->m_module_name, external_symbol->m_scope_names,
2703+
external_symbol->n_scope_names, external_symbol->m_original_name,
2704+
external_symbol->m_access));
2705+
}
2706+
2707+
ASR::symbol_t* duplicate_AssociateBlock(ASR::AssociateBlock_t* associate_block,
2708+
SymbolTable* destination_symtab) {
2709+
SymbolTable* associate_block_symtab = al.make_new<SymbolTable>(destination_symtab);
2710+
duplicate_SymbolTable(associate_block->m_symtab, associate_block_symtab);
2711+
Vec<ASR::stmt_t*> new_body;
2712+
new_body.reserve(al, associate_block->n_body);
2713+
ASRUtils::ExprStmtDuplicator node_duplicator(al);
2714+
node_duplicator.allow_procedure_calls = true;
2715+
node_duplicator.allow_reshape = false;
2716+
for( size_t i = 0; i < associate_block->n_body; i++ ) {
2717+
node_duplicator.success = true;
2718+
ASR::stmt_t* new_stmt = node_duplicator.duplicate_stmt(associate_block->m_body[i]);
2719+
if( !node_duplicator.success ) {
2720+
return nullptr;
2721+
}
2722+
new_body.push_back(al, new_stmt);
2723+
}
2724+
2725+
// node_duplicator_.allow_procedure_calls = true;
2726+
2727+
return ASR::down_cast<ASR::symbol_t>(ASR::make_AssociateBlock_t(al,
2728+
associate_block->base.base.loc, associate_block_symtab,
2729+
associate_block->m_name, new_body.p, new_body.size()));
2730+
}
2731+
2732+
ASR::symbol_t* duplicate_Function(ASR::Function_t* function,
2733+
SymbolTable* destination_symtab) {
2734+
SymbolTable* function_symtab = al.make_new<SymbolTable>(destination_symtab);
2735+
duplicate_SymbolTable(function->m_symtab, function_symtab);
2736+
Vec<ASR::stmt_t*> new_body;
2737+
new_body.reserve(al, function->n_body);
2738+
ASRUtils::ExprStmtDuplicator node_duplicator(al);
2739+
node_duplicator.allow_procedure_calls = true;
2740+
node_duplicator.allow_reshape = false;
2741+
for( size_t i = 0; i < function->n_body; i++ ) {
2742+
node_duplicator.success = true;
2743+
ASR::stmt_t* new_stmt = node_duplicator.duplicate_stmt(function->m_body[i]);
2744+
if( !node_duplicator.success ) {
2745+
return nullptr;
2746+
}
2747+
new_body.push_back(al, new_stmt);
2748+
}
2749+
2750+
Vec<ASR::expr_t*> new_args;
2751+
new_args.reserve(al, function->n_args);
2752+
for( size_t i = 0; i < function->n_args; i++ ) {
2753+
node_duplicator.success = true;
2754+
ASR::expr_t* new_arg = node_duplicator.duplicate_expr(function->m_args[i]);
2755+
if( !node_duplicator.success ) {
2756+
return nullptr;
2757+
}
2758+
new_args.push_back(al, new_arg);
2759+
}
2760+
2761+
node_duplicator.success = true;
2762+
ASR::expr_t* new_return_var = node_duplicator.duplicate_expr(function->m_return_var);
2763+
if( !node_duplicator.success ) {
2764+
return nullptr;
2765+
}
2766+
2767+
ASR::FunctionType_t* function_type = ASRUtils::get_FunctionType(function);
2768+
2769+
Vec<ASR::ttype_t*> new_ttypes;
2770+
new_ttypes.reserve(al, function_type->n_type_params);
2771+
for( size_t i = 0; i < function_type->n_type_params; i++ ) {
2772+
node_duplicator.success = true;
2773+
ASR::ttype_t* new_ttype = node_duplicator.duplicate_ttype(function_type->m_type_params[i]);
2774+
if( !node_duplicator.success ) {
2775+
return nullptr;
2776+
}
2777+
new_ttypes.push_back(al, new_ttype);
2778+
}
2779+
2780+
Vec<ASR::symbol_t*> new_restrictions;
2781+
new_restrictions.reserve(al, function_type->n_restrictions);
2782+
for( size_t i = 0; i < function_type->n_restrictions; i++ ) {
2783+
std::string restriction_name = ASRUtils::symbol_name(function_type->m_restrictions[i]);
2784+
ASR::symbol_t* new_restriction = function_symtab->resolve_symbol(restriction_name);
2785+
if( !new_restriction ) {
2786+
throw LCompilersException("Symbol " + restriction_name + " not found.");
2787+
}
2788+
new_restrictions.push_back(al, new_restriction);
2789+
}
2790+
2791+
return ASR::down_cast<ASR::symbol_t>(make_Function_t_util(al,
2792+
function->base.base.loc, function_symtab, function->m_name,
2793+
function->m_dependencies, function->n_dependencies, new_args.p,
2794+
new_args.size(), new_body.p, new_body.size(), new_return_var,
2795+
function_type->m_abi, function->m_access, function_type->m_deftype,
2796+
function_type->m_bindc_name, function_type->m_elemental, function_type->m_pure,
2797+
function_type->m_module, function_type->m_inline, function_type->m_static,
2798+
new_ttypes.p, new_ttypes.size(), new_restrictions.p, new_restrictions.size(),
2799+
function_type->m_is_restriction, function->m_deterministic,
2800+
function->m_side_effect_free));
2801+
}
2802+
2803+
};
2804+
25842805
class ReplaceReturnWithGotoVisitor: public ASR::BaseStmtReplacer<ReplaceReturnWithGotoVisitor> {
25852806

25862807
private:
@@ -2804,6 +3025,10 @@ static inline bool is_pass_array_by_data_possible(ASR::Function_t* x, std::vecto
28043025
continue;
28053026
}
28063027
typei = ASRUtils::expr_type(x->m_args[i]);
3028+
if( ASR::is_a<ASR::Class_t>(*typei) ||
3029+
ASR::is_a<ASR::FunctionType_t>(*typei) ) {
3030+
continue ;
3031+
}
28073032
int n_dims = ASRUtils::extract_dimensions_from_ttype(typei, dims);
28083033
ASR::Variable_t* argi = ASRUtils::EXPR2VAR(x->m_args[i]);
28093034
if( ASRUtils::is_dimension_empty(dims, n_dims) &&
@@ -2816,34 +3041,6 @@ static inline bool is_pass_array_by_data_possible(ASR::Function_t* x, std::vecto
28163041
return v.size() > 0;
28173042
}
28183043

2819-
inline ASR::asr_t* make_Function_t_util(Allocator& al, const Location& loc,
2820-
SymbolTable* m_symtab, char* m_name, char** m_dependencies, size_t n_dependencies,
2821-
ASR::expr_t** a_args, size_t n_args, ASR::stmt_t** m_body, size_t n_body,
2822-
ASR::expr_t* m_return_var, ASR::abiType m_abi, ASR::accessType m_access,
2823-
ASR::deftypeType m_deftype, char* m_bindc_name, bool m_elemental, bool m_pure,
2824-
bool m_module, bool m_inline, bool m_static, ASR::ttype_t** m_type_params,
2825-
size_t n_type_params, ASR::symbol_t** m_restrictions, size_t n_restrictions,
2826-
bool m_is_restriction, bool m_deterministic, bool m_side_effect_free) {
2827-
Vec<ASR::ttype_t*> arg_types;
2828-
arg_types.reserve(al, n_args);
2829-
for( size_t i = 0; i < n_args; i++ ) {
2830-
arg_types.push_back(al, ASRUtils::expr_type(a_args[i]));
2831-
}
2832-
ASR::ttype_t* return_var_type = nullptr;
2833-
if( m_return_var ) {
2834-
return_var_type = ASRUtils::expr_type(m_return_var);
2835-
}
2836-
ASR::ttype_t* func_type = ASRUtils::TYPE(ASR::make_FunctionType_t(
2837-
al, loc, arg_types.p, arg_types.size(), return_var_type, m_abi,
2838-
m_deftype, m_bindc_name, m_elemental, m_pure, m_module, m_inline,
2839-
m_static, m_type_params, n_type_params, m_restrictions, n_restrictions,
2840-
m_is_restriction));
2841-
return ASR::make_Function_t(
2842-
al, loc, m_symtab, m_name, func_type, m_dependencies, n_dependencies,
2843-
a_args, n_args, m_body, n_body, m_return_var, m_access, m_deterministic,
2844-
m_side_effect_free);
2845-
}
2846-
28473044
static inline ASR::expr_t* get_bound(ASR::expr_t* arr_expr, int dim,
28483045
std::string bound, Allocator& al) {
28493046
ASR::ttype_t* int32_type = ASRUtils::TYPE(ASR::make_Integer_t(al, arr_expr->base.loc,

0 commit comments

Comments
 (0)