@@ -139,6 +139,7 @@ ASR::Module_t* load_module(Allocator &al, SymbolTable *symtab,
139
139
return mod2;
140
140
}
141
141
142
+
142
143
template <class Derived >
143
144
class CommonVisitor : public AST ::BaseVisitor<Derived> {
144
145
public:
@@ -156,10 +157,13 @@ class CommonVisitor : public AST::BaseVisitor<Derived> {
156
157
// The main module is stored directly in TranslationUnit, other modules are Modules
157
158
bool main_module;
158
159
PythonIntrinsicProcedures intrinsic_procedures;
160
+ std::map<int , ASR::symbol_t *> &ast_overload;
159
161
160
162
CommonVisitor (Allocator &al, SymbolTable *symbol_table,
161
- diag::Diagnostics &diagnostics, bool main_module)
162
- : diag{diagnostics}, al{al}, current_scope{symbol_table}, main_module{main_module} {
163
+ diag::Diagnostics &diagnostics, bool main_module,
164
+ std::map<int , ASR::symbol_t *> &ast_overload)
165
+ : diag{diagnostics}, al{al}, current_scope{symbol_table}, main_module{main_module},
166
+ ast_overload{ast_overload} {
163
167
current_module_dependencies.reserve (al, 4 );
164
168
}
165
169
@@ -445,7 +449,7 @@ ASR::symbol_t* import_from_module(Allocator &al, ASR::Module_t *m, SymbolTable *
445
449
throw SemanticError (" Only Subroutines, Functions and Variables are currently supported in 'import'" ,
446
450
loc);
447
451
}
448
- // should not reach here
452
+ LFORTRAN_ASSERT ( false );
449
453
return nullptr ;
450
454
}
451
455
@@ -469,11 +473,13 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
469
473
std::map<SymbolTable*, ASR::accessType> assgn;
470
474
ASR::symbol_t *current_module_sym;
471
475
std::vector<std::string> excluded_from_symtab;
476
+ std::map<std::string, Vec<ASR::symbol_t * >> overload_defs;
472
477
473
478
474
479
SymbolTableVisitor (Allocator &al, SymbolTable *symbol_table,
475
- diag::Diagnostics &diagnostics, bool main_module)
476
- : CommonVisitor(al, symbol_table, diagnostics, main_module), is_derived_type{false } {}
480
+ diag::Diagnostics &diagnostics, bool main_module,
481
+ std::map<int , ASR::symbol_t *> &ast_overload)
482
+ : CommonVisitor(al, symbol_table, diagnostics, main_module, ast_overload), is_derived_type{false } {}
477
483
478
484
479
485
ASR::symbol_t * resolve_symbol (const Location &loc, const std::string &sub_name) {
@@ -522,7 +528,9 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
522
528
for (size_t i=0 ; i<x.n_body ; i++) {
523
529
visit_stmt (*x.m_body [i]);
524
530
}
525
-
531
+ if (!overload_defs.empty ()) {
532
+ create_GenericProcedure (x.base .base .loc );
533
+ }
526
534
global_scope = nullptr ;
527
535
tmp = tmp0;
528
536
}
@@ -534,12 +542,23 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
534
542
Vec<ASR::expr_t *> args;
535
543
args.reserve (al, x.m_args .n_args );
536
544
current_procedure_abi_type = ASR::abiType::Source;
537
- if (x.n_decorator_list == 1 ) {
538
- AST::expr_t *dec = x.m_decorator_list [0 ];
539
- if (AST::is_a<AST::Name_t>(*dec)) {
540
- std::string name = AST::down_cast<AST::Name_t>(dec)->m_id ;
541
- if (name == " ccall" ) {
542
- current_procedure_abi_type = ASR::abiType::BindC;
545
+ bool overload = false ;
546
+ if (x.n_decorator_list > 0 ) {
547
+ for (size_t i=0 ; i<x.n_decorator_list ; i++) {
548
+ AST::expr_t *dec = x.m_decorator_list [i];
549
+ if (AST::is_a<AST::Name_t>(*dec)) {
550
+ std::string name = AST::down_cast<AST::Name_t>(dec)->m_id ;
551
+ if (name == " ccall" ) {
552
+ current_procedure_abi_type = ASR::abiType::BindC;
553
+ } else if (name == " overload" ) {
554
+ overload = true ;
555
+ } else {
556
+ throw SemanticError (" Decorator: " + name + " is not supported" ,
557
+ x.base .base .loc );
558
+ }
559
+ } else {
560
+ throw SemanticError (" Unsupported Decorator type" ,
561
+ x.base .base .loc );
543
562
}
544
563
}
545
564
}
@@ -578,6 +597,18 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
578
597
var)));
579
598
}
580
599
std::string sym_name = x.m_name ;
600
+ if (overload) {
601
+ std::string overload_number;
602
+ if (overload_defs.find (sym_name) == overload_defs.end ()){
603
+ overload_number = " 0" ;
604
+ Vec<ASR::symbol_t *> v;
605
+ v.reserve (al, 1 );
606
+ overload_defs[sym_name] = v;
607
+ } else {
608
+ overload_number = std::to_string (overload_defs[sym_name].size ());
609
+ }
610
+ sym_name = " __lpython_overloaded_" + overload_number + " __" + sym_name;
611
+ }
581
612
if (parent_scope->scope .find (sym_name) != parent_scope->scope .end ()) {
582
613
throw SemanticError (" Subroutine already defined" , tmp->loc );
583
614
}
@@ -631,8 +662,23 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
631
662
s_access, deftype, bindc_name,
632
663
is_pure, is_module);
633
664
}
634
- parent_scope->scope [sym_name] = ASR::down_cast<ASR::symbol_t >(tmp);
665
+ ASR::symbol_t * t = ASR::down_cast<ASR::symbol_t >(tmp);
666
+ parent_scope->scope [sym_name] = t;
635
667
current_scope = parent_scope;
668
+ if (overload) {
669
+ overload_defs[x.m_name ].push_back (al, t);
670
+ ast_overload[(int64_t )&x] = t;
671
+ }
672
+ }
673
+
674
+ void create_GenericProcedure (const Location &loc) {
675
+ for (auto &p: overload_defs) {
676
+ std::string def_name = p.first ;
677
+ tmp = ASR::make_GenericProcedure_t (al, loc, current_scope, s2c (al, def_name),
678
+ p.second .p , p.second .size (), ASR::accessType::Public);
679
+ ASR::symbol_t *t = ASR::down_cast<ASR::symbol_t >(tmp);
680
+ current_scope->scope [def_name] = t;
681
+ }
636
682
}
637
683
638
684
void visit_ImportFrom (const AST::ImportFrom_t &x) {
@@ -724,9 +770,10 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
724
770
};
725
771
726
772
Result<ASR::asr_t *> symbol_table_visitor (Allocator &al, const AST::Module_t &ast,
727
- diag::Diagnostics &diagnostics, bool main_module)
773
+ diag::Diagnostics &diagnostics, bool main_module,
774
+ std::map<int , ASR::symbol_t *> &ast_overload)
728
775
{
729
- SymbolTableVisitor v (al, nullptr , diagnostics, main_module);
776
+ SymbolTableVisitor v (al, nullptr , diagnostics, main_module, ast_overload );
730
777
try {
731
778
v.visit_Module (ast);
732
779
} catch (const SemanticError &e) {
@@ -748,8 +795,9 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
748
795
ASR::asr_t *asr;
749
796
Vec<ASR::stmt_t *> *current_body;
750
797
751
- BodyVisitor (Allocator &al, ASR::asr_t *unit, diag::Diagnostics &diagnostics, bool main_module)
752
- : CommonVisitor(al, nullptr , diagnostics, main_module), asr{unit} {}
798
+ BodyVisitor (Allocator &al, ASR::asr_t *unit, diag::Diagnostics &diagnostics,
799
+ bool main_module, std::map<int , ASR::symbol_t *> &ast_overload)
800
+ : CommonVisitor(al, nullptr , diagnostics, main_module, ast_overload), asr{unit} {}
753
801
754
802
// Transforms statements to a list of ASR statements
755
803
// In addition, it also inserts the following nodes if needed:
@@ -817,6 +865,16 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
817
865
} else if (ASR::is_a<ASR::Function_t>(*t)) {
818
866
ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(t);
819
867
handle_fn (x, *f);
868
+ } else if (ASR::is_a<ASR::GenericProcedure_t>(*t)) {
869
+ ASR::symbol_t *s = ast_overload[(int64_t )&x];
870
+ if (ASR::is_a<ASR::Subroutine_t>(*s)) {
871
+ handle_fn (x, *ASR::down_cast<ASR::Subroutine_t>(s));
872
+ } else if (ASR::is_a<ASR::Function_t>(*s)) {
873
+ ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(s);
874
+ handle_fn (x, *f);
875
+ } else {
876
+ LFORTRAN_ASSERT (false );
877
+ }
820
878
} else {
821
879
LFORTRAN_ASSERT (false );
822
880
}
@@ -2108,8 +2166,15 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
2108
2166
x.base .base .loc );
2109
2167
}
2110
2168
2111
- ASR::symbol_t *s = current_scope->resolve_symbol (call_name);
2112
-
2169
+ ASR::symbol_t *s = current_scope->resolve_symbol (call_name), *s_generic = nullptr ;
2170
+ if (s!=nullptr && s->type == ASR::symbolType::GenericProcedure) {
2171
+ ASR::GenericProcedure_t *p = ASR::down_cast<ASR::GenericProcedure_t>(s);
2172
+ int idx = ASRUtils::select_generic_procedure (args, *p, x.base .base .loc ,
2173
+ [&](const std::string &msg, const Location &loc) { throw SemanticError (msg, loc); });
2174
+ // Create ExternalSymbol for procedures in different modules.
2175
+ s_generic = s;
2176
+ s = p->m_procs [idx];
2177
+ }
2113
2178
2114
2179
if (!s) {
2115
2180
if (intrinsic_procedures.is_intrinsic (call_name)) {
@@ -2246,10 +2311,10 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
2246
2311
value = intrinsic_procedures.comptime_eval (call_name, al, x.base .base .loc , args);
2247
2312
}
2248
2313
tmp = ASR::make_FunctionCall_t (al, x.base .base .loc , stemp,
2249
- nullptr , args.p , args.size (), nullptr , 0 , a_type, value, nullptr );
2314
+ s_generic , args.p , args.size (), nullptr , 0 , a_type, value, nullptr );
2250
2315
} else if (ASR::is_a<ASR::Subroutine_t>(*s)) {
2251
2316
tmp = ASR::make_SubroutineCall_t (al, x.base .base .loc , stemp,
2252
- nullptr , args.p , args.size (), nullptr );
2317
+ s_generic , args.p , args.size (), nullptr );
2253
2318
} else {
2254
2319
throw SemanticError (" Unsupported call type for " + call_name,
2255
2320
x.base .base .loc );
@@ -2265,9 +2330,10 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
2265
2330
Result<ASR::TranslationUnit_t*> body_visitor (Allocator &al,
2266
2331
const AST::Module_t &ast,
2267
2332
diag::Diagnostics &diagnostics,
2268
- ASR::asr_t *unit, bool main_module)
2333
+ ASR::asr_t *unit, bool main_module,
2334
+ std::map<int , ASR::symbol_t *> &ast_overload)
2269
2335
{
2270
- BodyVisitor b (al, unit, diagnostics, main_module);
2336
+ BodyVisitor b (al, unit, diagnostics, main_module, ast_overload );
2271
2337
try {
2272
2338
b.visit_Module (ast);
2273
2339
} catch (const SemanticError &e) {
@@ -2301,10 +2367,13 @@ std::string pickle_python(AST::ast_t &ast, bool colors, bool indent) {
2301
2367
Result<ASR::TranslationUnit_t*> python_ast_to_asr (Allocator &al,
2302
2368
AST::ast_t &ast, diag::Diagnostics &diagnostics, bool main_module)
2303
2369
{
2370
+ std::map<int , ASR::symbol_t *> ast_overload;
2371
+
2304
2372
AST::Module_t *ast_m = AST::down_cast2<AST::Module_t>(&ast);
2305
2373
2306
2374
ASR::asr_t *unit;
2307
- auto res = symbol_table_visitor (al, *ast_m, diagnostics, main_module);
2375
+ auto res = symbol_table_visitor (al, *ast_m, diagnostics, main_module,
2376
+ ast_overload);
2308
2377
if (res.ok ) {
2309
2378
unit = res.result ;
2310
2379
} else {
@@ -2313,7 +2382,8 @@ Result<ASR::TranslationUnit_t*> python_ast_to_asr(Allocator &al,
2313
2382
ASR::TranslationUnit_t *tu = ASR::down_cast2<ASR::TranslationUnit_t>(unit);
2314
2383
LFORTRAN_ASSERT (asr_verify (*tu));
2315
2384
2316
- auto res2 = body_visitor (al, *ast_m, diagnostics, unit, main_module);
2385
+ auto res2 = body_visitor (al, *ast_m, diagnostics, unit, main_module,
2386
+ ast_overload);
2317
2387
if (res2.ok ) {
2318
2388
tu = res2.result ;
2319
2389
} else {
0 commit comments