@@ -717,10 +717,24 @@ class CommonVisitor : public AST::BaseVisitor<Derived> {
717
717
stemp = symtab->get_symbol (local_sym);
718
718
}
719
719
}
720
- if (ASR::is_a<ASR::Function_t>(*s) &&
721
- ASR::down_cast<ASR::Function_t>(s)->m_return_var != nullptr ) {
720
+ if (ASR::is_a<ASR::Function_t>(*s)) {
722
721
ASR::Function_t *func = ASR::down_cast<ASR::Function_t>(s);
723
- if (func->n_type_params == 0 ) {
722
+ if (func->n_type_params > 0 ) {
723
+ std::map<std::string, ASR::ttype_t *> subs;
724
+ for (size_t i=0 ; i<args.size (); i++) {
725
+ ASR::ttype_t *param_type = ASRUtils::expr_type (func->m_args [i]);
726
+ ASR::ttype_t *arg_type = ASRUtils::expr_type (args[i].m_value );
727
+ subs = check_type_substitution (subs, param_type, arg_type, loc);
728
+ }
729
+
730
+ ASR::symbol_t *t = get_generic_function (subs, *func);
731
+ std::string new_call_name = call_name;
732
+ if (ASR::is_a<ASR::Function_t>(*t)) {
733
+ new_call_name = (ASR::down_cast<ASR::Function_t>(t))->m_name ;
734
+ }
735
+ return make_call_helper (al, t, current_scope, args, new_call_name, loc);
736
+ }
737
+ if (ASR::down_cast<ASR::Function_t>(s)->m_return_var != nullptr ) {
724
738
ASR::ttype_t *a_type = nullptr ;
725
739
if ( func->m_elemental && args.size () == 1 &&
726
740
ASRUtils::is_array (ASRUtils::expr_type (args[0 ].m_value )) ) {
@@ -766,39 +780,25 @@ class CommonVisitor : public AST::BaseVisitor<Derived> {
766
780
return func_call_asr;
767
781
}
768
782
} else {
769
- std::map<std::string, ASR::ttype_t *> subs;
770
- for (size_t i=0 ; i<args.size (); i++) {
771
- ASR::ttype_t *param_type = ASRUtils::expr_type (func->m_args [i]);
772
- ASR::ttype_t *arg_type = ASRUtils::expr_type (args[i].m_value );
773
- subs = check_type_substitution (subs, param_type, arg_type, loc);
774
- }
775
-
776
- ASR::symbol_t *t = get_generic_function (subs, *func);
777
- std::string new_call_name = call_name;
778
- if (ASR::is_a<ASR::Function_t>(*t)) {
779
- new_call_name = (ASR::down_cast<ASR::Function_t>(t))->m_name ;
783
+ ASR::Function_t *func = ASR::down_cast<ASR::Function_t>(s);
784
+ if (args.size () != func->n_args ) {
785
+ std::string fnd = std::to_string (args.size ());
786
+ std::string org = std::to_string (func->n_args );
787
+ diag.add (diag::Diagnostic (
788
+ " Number of arguments does not match in the function call" ,
789
+ diag::Level::Error, diag::Stage::Semantic, {
790
+ diag::Label (" (found: '" + fnd + " ', expected: '" + org + " ')" ,
791
+ {loc})
792
+ })
793
+ );
794
+ throw SemanticAbort ();
780
795
}
781
- return make_call_helper (al, t, current_scope, args, new_call_name, loc);
782
- }
783
- } else if (ASR::is_a<ASR::Function_t>(*s)) {
784
- ASR::Function_t *func = ASR::down_cast<ASR::Function_t>(s);
785
- if (args.size () != func->n_args ) {
786
- std::string fnd = std::to_string (args.size ());
787
- std::string org = std::to_string (func->n_args );
788
- diag.add (diag::Diagnostic (
789
- " Number of arguments does not match in the function call" ,
790
- diag::Level::Error, diag::Stage::Semantic, {
791
- diag::Label (" (found: '" + fnd + " ', expected: '" + org + " ')" ,
792
- {loc})
793
- })
794
- );
795
- throw SemanticAbort ();
796
+ Vec<ASR::call_arg_t > args_new;
797
+ args_new.reserve (al, func->n_args );
798
+ visit_expr_list_with_cast (func->m_args , func->n_args , args_new, args);
799
+ return ASR::make_SubroutineCall_t (al, loc, stemp,
800
+ s_generic, args_new.p , args_new.size (), nullptr );
796
801
}
797
- Vec<ASR::call_arg_t > args_new;
798
- args_new.reserve (al, func->n_args );
799
- visit_expr_list_with_cast (func->m_args , func->n_args , args_new, args);
800
- return ASR::make_SubroutineCall_t (al, loc, stemp,
801
- s_generic, args_new.p , args_new.size (), nullptr );
802
802
} else if (ASR::is_a<ASR::DerivedType_t>(*s)) {
803
803
Vec<ASR::expr_t *> args_new;
804
804
args_new.reserve (al, args.size ());
@@ -891,8 +891,9 @@ class CommonVisitor : public AST::BaseVisitor<Derived> {
891
891
new_function_num = 0 ;
892
892
}
893
893
generic_func_nums[func_name] = new_function_num + 1 ;
894
- generic_func_subs[" __lpython_generic_" + func_name + " _" + std::to_string (new_function_num)] = subs;
895
- t = pass_instantiate_generic_function (al, subs, current_scope, new_function_num, func);
894
+ std::string new_func_name = " __lpython_generic_" + func_name + " _" + std::to_string (new_function_num);
895
+ generic_func_subs[new_func_name] = subs;
896
+ t = pass_instantiate_generic_function (al, subs, current_scope, new_func_name, func);
896
897
return t;
897
898
}
898
899
@@ -2273,7 +2274,8 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
2273
2274
current_procedure_abi_type = ASR::abiType::Source;
2274
2275
bool current_procedure_interface = false ;
2275
2276
bool overload = false ;
2276
- std::set<std::string> ps;
2277
+ Vec<ASR::ttype_t *> tps;
2278
+ tps.reserve (al, x.m_args .n_args );
2277
2279
bool vectorize = false ;
2278
2280
if (x.n_decorator_list > 0 ) {
2279
2281
for (size_t i=0 ; i<x.n_decorator_list ; i++) {
@@ -2310,8 +2312,25 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
2310
2312
ASR::ttype_t *arg_type = ast_expr_to_asr_type (x.base .base .loc , *x.m_args .m_args [i].m_annotation );
2311
2313
// Set the function as generic if an argument is typed with a type parameter
2312
2314
if (ASRUtils::is_generic (*arg_type)) {
2313
- std::string param_name = ASRUtils::get_parameter_name (arg_type);
2314
- ps.insert (param_name);
2315
+ ASR::ttype_t *new_tt = ASRUtils::duplicate_type_without_dims (al, ASRUtils::get_type_parameter (arg_type));
2316
+ size_t current_size = tps.size ();
2317
+ if (current_size == 0 ) {
2318
+ tps.push_back (al, new_tt);
2319
+ } else {
2320
+ bool not_found = true ;
2321
+ for (size_t i = 0 ; i < current_size; i++) {
2322
+ ASR::TypeParameter_t *added_tp = ASR::down_cast<ASR::TypeParameter_t>(tps.p [i]);
2323
+ std::string new_param = ASR::down_cast<ASR::TypeParameter_t>(new_tt)->m_param ;
2324
+ std::string added_param = added_tp->m_param ;
2325
+ if (added_param.compare (new_param) == 0 ) {
2326
+ not_found = false ;
2327
+ break ;
2328
+ }
2329
+ }
2330
+ if (not_found) {
2331
+ tps.push_back (al, new_tt);
2332
+ }
2333
+ }
2315
2334
}
2316
2335
2317
2336
std::string arg_s = arg;
@@ -2377,43 +2396,19 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
2377
2396
ASR::down_cast<ASR::symbol_t >(return_var));
2378
2397
ASR::asr_t *return_var_ref = ASR::make_Var_t (al, x.base .base .loc ,
2379
2398
current_scope->get_symbol (return_var_name));
2380
- if (ps.size () > 0 ) {
2381
- Vec<ASR::ttype_t *> type_params;
2382
- type_params.reserve (al, ps.size ());
2383
- for (auto &p: ps) {
2384
- std::string param = p;
2385
- ASR::ttype_t *type_p = ASRUtils::TYPE (ASR::make_TypeParameter_t (al,
2386
- x.base .base .loc , s2c (al, p), nullptr , 0 ));
2387
- type_params.push_back (al, type_p);
2388
- }
2389
- tmp = ASR::make_Function_t (
2390
- al, x.base .base .loc ,
2391
- /* a_symtab */ current_scope,
2392
- /* a_name */ s2c (al, sym_name),
2393
- /* a_args */ args.p ,
2394
- /* n_args */ args.size (),
2395
- /* a_type_params */ type_params.p ,
2396
- /* n_type_params */ type_params.size (),
2397
- /* a_body */ nullptr ,
2398
- /* n_body */ 0 ,
2399
- /* a_return_var */ ASRUtils::EXPR (return_var_ref),
2400
- current_procedure_abi_type,
2401
- s_access, deftype, bindc_name, vectorize, false , false );
2402
- } else {
2403
- tmp = ASR::make_Function_t (
2404
- al, x.base .base .loc ,
2405
- /* a_symtab */ current_scope,
2406
- /* a_name */ s2c (al, sym_name),
2407
- /* a_args */ args.p ,
2408
- /* n_args */ args.size (),
2409
- /* a_type_params */ nullptr ,
2410
- /* n_type_params */ 0 ,
2411
- /* a_body */ nullptr ,
2412
- /* n_body */ 0 ,
2413
- /* a_return_var */ ASRUtils::EXPR (return_var_ref),
2414
- current_procedure_abi_type,
2415
- s_access, deftype, bindc_name, vectorize, false , false );
2416
- }
2399
+ tmp = ASR::make_Function_t (
2400
+ al, x.base .base .loc ,
2401
+ /* a_symtab */ current_scope,
2402
+ /* a_name */ s2c (al, sym_name),
2403
+ /* a_args */ args.p ,
2404
+ /* n_args */ args.size (),
2405
+ /* a_type_params */ tps.p ,
2406
+ /* n_type_params */ tps.size (),
2407
+ /* a_body */ nullptr ,
2408
+ /* n_body */ 0 ,
2409
+ /* a_return_var */ ASRUtils::EXPR (return_var_ref),
2410
+ current_procedure_abi_type,
2411
+ s_access, deftype, bindc_name, vectorize, false , false );
2417
2412
} else {
2418
2413
throw SemanticError (" Return variable must be an identifier (Name AST node) or an array (Subscript AST node)" ,
2419
2414
x.m_returns ->base .loc );
@@ -2426,7 +2421,8 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
2426
2421
/* a_name */ s2c (al, sym_name),
2427
2422
/* a_args */ args.p ,
2428
2423
/* n_args */ args.size (),
2429
- nullptr , 0 ,
2424
+ /* a_type_params */ tps.p ,
2425
+ /* n_type_params */ tps.size (),
2430
2426
/* a_body */ nullptr ,
2431
2427
/* n_body */ 0 ,
2432
2428
nullptr ,
0 commit comments