@@ -140,13 +140,13 @@ ASR::Module_t* load_module(Allocator &al, SymbolTable *symtab,
140
140
}
141
141
142
142
template <typename T>
143
- bool argument_types_match (const Vec<ASR::ttype_t *> &args,
143
+ bool argument_types_match (const Vec<ASR::expr_t *> &args,
144
144
const T &sub) {
145
145
if (args.size () <= sub.n_args ) {
146
146
size_t i;
147
147
for (i = 0 ; i < args.size (); i++) {
148
148
ASR::Variable_t *v = LFortran::ASRUtils::EXPR2VAR (sub.m_args [i]);
149
- ASR::ttype_t *arg1 = args[i];
149
+ ASR::ttype_t *arg1 = ASRUtils::expr_type ( args[i]) ;
150
150
ASR::ttype_t *arg2 = v->m_type ;
151
151
if (!ASRUtils::check_equal_type (arg1, arg2)) {
152
152
return false ;
@@ -164,7 +164,7 @@ bool argument_types_match(const Vec<ASR::ttype_t*> &args,
164
164
}
165
165
}
166
166
167
- bool select_func_subrout (const ASR::symbol_t * proc, const Vec<ASR::ttype_t *> &args,
167
+ bool select_func_subrout (const ASR::symbol_t * proc, const Vec<ASR::expr_t *> &args,
168
168
const Location& loc, const std::function<void (const std::string &, const Location &)> err) {
169
169
bool result = false ;
170
170
if (ASR::is_a<ASR::Subroutine_t>(*proc)) {
@@ -185,8 +185,7 @@ bool select_func_subrout(const ASR::symbol_t* proc, const Vec<ASR::ttype_t*> &ar
185
185
return result;
186
186
}
187
187
188
- std::map<std::string, std::vector<std::string>> overload_definitons;
189
-
188
+ std::map<int , ASR::symbol_t *> ast_overload;
190
189
template <class Derived >
191
190
class CommonVisitor : public AST ::BaseVisitor<Derived> {
192
191
public:
@@ -204,7 +203,7 @@ class CommonVisitor : public AST::BaseVisitor<Derived> {
204
203
// The main module is stored directly in TranslationUnit, other modules are Modules
205
204
bool main_module;
206
205
PythonIntrinsicProcedures intrinsic_procedures;
207
- std::map<std::string, std::vector<std::string >> overload_defs;
206
+ std::map<std::string, Vec<ASR:: symbol_t * >> overload_defs;
208
207
209
208
CommonVisitor (Allocator &al, SymbolTable *symbol_table,
210
209
diag::Diagnostics &diagnostics, bool main_module)
@@ -571,7 +570,9 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
571
570
for (size_t i=0 ; i<x.n_body ; i++) {
572
571
visit_stmt (*x.m_body [i]);
573
572
}
574
-
573
+ if (!overload_defs.empty ()) {
574
+ create_GenericProcedure (x.base .base .loc );
575
+ }
575
576
global_scope = nullptr ;
576
577
tmp = tmp0;
577
578
}
@@ -642,6 +643,9 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
642
643
std::string overload_number;
643
644
if (overload_defs.find (sym_name) == overload_defs.end ()){
644
645
overload_number = " 0" ;
646
+ Vec<ASR::symbol_t *> v;
647
+ v.reserve (al, 1 );
648
+ overload_defs[sym_name] = v;
645
649
} else {
646
650
overload_number = std::to_string (overload_defs[sym_name].size ());
647
651
}
@@ -700,10 +704,22 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
700
704
s_access, deftype, bindc_name,
701
705
is_pure, is_module);
702
706
}
703
- parent_scope->scope [sym_name] = ASR::down_cast<ASR::symbol_t >(tmp);
707
+ ASR::symbol_t * t = ASR::down_cast<ASR::symbol_t >(tmp);
708
+ parent_scope->scope [sym_name] = t;
704
709
current_scope = parent_scope;
705
710
if (overload) {
706
- overload_defs[x.m_name ].push_back (sym_name);
711
+ overload_defs[x.m_name ].push_back (al, t);
712
+ ast_overload[(int64_t )&x] = t;
713
+ }
714
+ }
715
+
716
+ void create_GenericProcedure (const Location &loc) {
717
+ for (auto &p: overload_defs) {
718
+ std::string def_name = p.first ;
719
+ tmp = ASR::make_GenericProcedure_t (al, loc, current_scope, s2c (al, def_name),
720
+ p.second .p , p.second .size (), ASR::accessType::Public);
721
+ ASR::symbol_t *t = ASR::down_cast<ASR::symbol_t >(tmp);
722
+ current_scope->scope [def_name] = t;
707
723
}
708
724
}
709
725
@@ -801,7 +817,6 @@ Result<ASR::asr_t*> symbol_table_visitor(Allocator &al, const AST::Module_t &ast
801
817
SymbolTableVisitor v (al, nullptr , diagnostics, main_module);
802
818
try {
803
819
v.visit_Module (ast);
804
- overload_definitons = v.overload_defs ;
805
820
} catch (const SemanticError &e) {
806
821
Error error;
807
822
diagnostics.diagnostics .push_back (e.d );
@@ -882,44 +897,24 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
882
897
v.n_body = body.size ();
883
898
}
884
899
885
- ASR::symbol_t * overloaddef_find_helper (std::string func_name, Vec<ASR::ttype_t *> args,
886
- const Location &loc) {
887
- for (auto &t: overload_defs[func_name]) {
888
- SymbolTable *symtab = current_scope;
889
- while (symtab!= nullptr && symtab->scope .find (t) == symtab->scope .end ()) {
890
- symtab = symtab->parent ;
891
- }
892
- LFORTRAN_ASSERT (symtab != nullptr );
893
- ASR::symbol_t *st = symtab->scope [t];
894
- bool ok = select_func_subrout (st, args, loc,
895
- [&](const std::string &msg, const Location &l) { throw SemanticError (msg, l); });
896
- if (ok) {
897
- return st;
898
- }
899
- }
900
- return nullptr ;
901
- }
902
-
903
900
void visit_FunctionDef (const AST::FunctionDef_t &x) {
904
901
SymbolTable *old_scope = current_scope;
905
- ASR::symbol_t *t = nullptr ;
906
- if (overload_defs.find (x.m_name ) != overload_defs.end ()) {
907
- Vec<ASR::ttype_t *> args;
908
- args.reserve (al, x.m_args .n_args );
909
- for (size_t i=0 ; i<x.m_args .n_args ; i++) {
910
- ASR::ttype_t *arg_type = ast_expr_to_asr_type (x.base .base .loc ,
911
- *x.m_args .m_args [i].m_annotation );
912
- args.push_back (al, arg_type);
913
- }
914
- t = overloaddef_find_helper (x.m_name , args, x.base .base .loc );
915
- } else {
916
- t = current_scope->scope [x.m_name ];
917
- }
902
+ ASR::symbol_t *t = t = current_scope->scope [x.m_name ];
918
903
if (ASR::is_a<ASR::Subroutine_t>(*t)) {
919
904
handle_fn (x, *ASR::down_cast<ASR::Subroutine_t>(t));
920
905
} else if (ASR::is_a<ASR::Function_t>(*t)) {
921
906
ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(t);
922
907
handle_fn (x, *f);
908
+ } else if (ASR::is_a<ASR::GenericProcedure_t>(*t)) {
909
+ ASR::symbol_t *s = ast_overload[(int64_t )&x];
910
+ if (ASR::is_a<ASR::Subroutine_t>(*s)) {
911
+ handle_fn (x, *ASR::down_cast<ASR::Subroutine_t>(s));
912
+ } else if (ASR::is_a<ASR::Function_t>(*s)) {
913
+ ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(s);
914
+ handle_fn (x, *f);
915
+ } else {
916
+ LFORTRAN_ASSERT (false );
917
+ }
923
918
} else {
924
919
LFORTRAN_ASSERT (false );
925
920
}
@@ -2192,15 +2187,13 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
2192
2187
x.base .base .loc );
2193
2188
}
2194
2189
2195
- ASR::symbol_t *s = current_scope->resolve_symbol (call_name);
2196
-
2197
- if (!s && overload_defs.find (call_name)!=overload_defs.end ()) {
2198
- Vec<ASR::ttype_t *> args_type;
2199
- args_type.reserve (al, x.n_args );
2200
- for (size_t i=0 ; i<x.n_args ; i++) {
2201
- args_type.push_back (al, ASRUtils::expr_type (args[i]));
2202
- }
2203
- s = overloaddef_find_helper (call_name, args_type, x.base .base .loc );
2190
+ ASR::symbol_t *s = current_scope->resolve_symbol (call_name), *s_generic = nullptr ;
2191
+ if (s->type == ASR::symbolType::GenericProcedure){
2192
+ ASR::GenericProcedure_t *p = ASR::down_cast<ASR::GenericProcedure_t>(s);
2193
+ int idx = select_generic_procedure (args, *p, x.base .base .loc );
2194
+ // Create ExternalSymbol for procedures in different modules.
2195
+ s_generic = s;
2196
+ s = p->m_procs [idx];
2204
2197
}
2205
2198
2206
2199
if (!s) {
@@ -2347,6 +2340,27 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
2347
2340
x.base .base .loc );
2348
2341
}
2349
2342
}
2343
+ int select_generic_procedure (const Vec<ASR::expr_t *> &args,
2344
+ const ASR::GenericProcedure_t &p, Location loc) {
2345
+ for (size_t i=0 ; i < p.n_procs ; i++) {
2346
+
2347
+ if ( ASR::is_a<ASR::ClassProcedure_t>(*p.m_procs [i]) ) {
2348
+ ASR::ClassProcedure_t *clss_fn
2349
+ = ASR::down_cast<ASR::ClassProcedure_t>(p.m_procs [i]);
2350
+ const ASR::symbol_t *proc = ASRUtils::symbol_get_past_external (clss_fn->m_proc );
2351
+ if ( select_func_subrout (proc, args, loc,
2352
+ [&](const std::string &msg, const Location &loc) { throw SemanticError (msg, loc); })
2353
+ ){
2354
+ return i;
2355
+ }
2356
+ } else {
2357
+ if ( select_func_subrout (p.m_procs [i], args, loc, [&](const std::string &msg, const Location &loc) { throw SemanticError (msg, loc); }) ) {
2358
+ return i;
2359
+ }
2360
+ }
2361
+ }
2362
+ throw SemanticError (" Arguments do not match for any generic procedure" , loc);
2363
+ }
2350
2364
2351
2365
void visit_ImportFrom (const AST::ImportFrom_t &/* x*/ ) {
2352
2366
// Handled by SymbolTableVisitor already
@@ -2361,7 +2375,6 @@ Result<ASR::TranslationUnit_t*> body_visitor(Allocator &al,
2361
2375
{
2362
2376
BodyVisitor b (al, unit, diagnostics, main_module);
2363
2377
try {
2364
- b.overload_defs = overload_definitons;
2365
2378
b.visit_Module (ast);
2366
2379
} catch (const SemanticError &e) {
2367
2380
Error error;
0 commit comments