Skip to content

Commit 2b24e45

Browse files
committed
use generic procedure
1 parent 06ad19c commit 2b24e45

File tree

1 file changed

+64
-51
lines changed

1 file changed

+64
-51
lines changed

src/lpython/semantics/python_ast_to_asr.cpp

Lines changed: 64 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -140,13 +140,13 @@ ASR::Module_t* load_module(Allocator &al, SymbolTable *symtab,
140140
}
141141

142142
template <typename T>
143-
bool argument_types_match(const Vec<ASR::ttype_t*> &args,
143+
bool argument_types_match(const Vec<ASR::expr_t*> &args,
144144
const T &sub) {
145145
if (args.size() <= sub.n_args) {
146146
size_t i;
147147
for (i = 0; i < args.size(); i++) {
148148
ASR::Variable_t *v = LFortran::ASRUtils::EXPR2VAR(sub.m_args[i]);
149-
ASR::ttype_t *arg1 = args[i];
149+
ASR::ttype_t *arg1 = ASRUtils::expr_type(args[i]);
150150
ASR::ttype_t *arg2 = v->m_type;
151151
if (!ASRUtils::check_equal_type(arg1, arg2)) {
152152
return false;
@@ -164,7 +164,7 @@ bool argument_types_match(const Vec<ASR::ttype_t*> &args,
164164
}
165165
}
166166

167-
bool select_func_subrout(const ASR::symbol_t* proc, const Vec<ASR::ttype_t*> &args,
167+
bool select_func_subrout(const ASR::symbol_t* proc, const Vec<ASR::expr_t*> &args,
168168
const Location& loc, const std::function<void (const std::string &, const Location &)> err) {
169169
bool result = false;
170170
if (ASR::is_a<ASR::Subroutine_t>(*proc)) {
@@ -185,8 +185,7 @@ bool select_func_subrout(const ASR::symbol_t* proc, const Vec<ASR::ttype_t*> &ar
185185
return result;
186186
}
187187

188-
std::map<std::string, std::vector<std::string>> overload_definitons;
189-
188+
std::map<int, ASR::symbol_t*> ast_overload;
190189
template <class Derived>
191190
class CommonVisitor : public AST::BaseVisitor<Derived> {
192191
public:
@@ -204,7 +203,7 @@ class CommonVisitor : public AST::BaseVisitor<Derived> {
204203
// The main module is stored directly in TranslationUnit, other modules are Modules
205204
bool main_module;
206205
PythonIntrinsicProcedures intrinsic_procedures;
207-
std::map<std::string, std::vector<std::string>> overload_defs;
206+
std::map<std::string, Vec<ASR::symbol_t* >> overload_defs;
208207

209208
CommonVisitor(Allocator &al, SymbolTable *symbol_table,
210209
diag::Diagnostics &diagnostics, bool main_module)
@@ -571,7 +570,9 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
571570
for (size_t i=0; i<x.n_body; i++) {
572571
visit_stmt(*x.m_body[i]);
573572
}
574-
573+
if (!overload_defs.empty()) {
574+
create_GenericProcedure(x.base.base.loc);
575+
}
575576
global_scope = nullptr;
576577
tmp = tmp0;
577578
}
@@ -642,6 +643,9 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
642643
std::string overload_number;
643644
if (overload_defs.find(sym_name) == overload_defs.end()){
644645
overload_number = "0";
646+
Vec<ASR::symbol_t *> v;
647+
v.reserve(al, 1);
648+
overload_defs[sym_name] = v;
645649
} else {
646650
overload_number = std::to_string(overload_defs[sym_name].size());
647651
}
@@ -700,10 +704,22 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
700704
s_access, deftype, bindc_name,
701705
is_pure, is_module);
702706
}
703-
parent_scope->scope[sym_name] = ASR::down_cast<ASR::symbol_t>(tmp);
707+
ASR::symbol_t * t = ASR::down_cast<ASR::symbol_t>(tmp);
708+
parent_scope->scope[sym_name] = t;
704709
current_scope = parent_scope;
705710
if (overload) {
706-
overload_defs[x.m_name].push_back(sym_name);
711+
overload_defs[x.m_name].push_back(al, t);
712+
ast_overload[(int64_t)&x] = t;
713+
}
714+
}
715+
716+
void create_GenericProcedure(const Location &loc) {
717+
for(auto &p: overload_defs) {
718+
std::string def_name = p.first;
719+
tmp = ASR::make_GenericProcedure_t(al, loc, current_scope, s2c(al, def_name),
720+
p.second.p, p.second.size(), ASR::accessType::Public);
721+
ASR::symbol_t *t = ASR::down_cast<ASR::symbol_t>(tmp);
722+
current_scope->scope[def_name] = t;
707723
}
708724
}
709725

@@ -801,7 +817,6 @@ Result<ASR::asr_t*> symbol_table_visitor(Allocator &al, const AST::Module_t &ast
801817
SymbolTableVisitor v(al, nullptr, diagnostics, main_module);
802818
try {
803819
v.visit_Module(ast);
804-
overload_definitons = v.overload_defs;
805820
} catch (const SemanticError &e) {
806821
Error error;
807822
diagnostics.diagnostics.push_back(e.d);
@@ -882,44 +897,24 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
882897
v.n_body = body.size();
883898
}
884899

885-
ASR::symbol_t* overloaddef_find_helper(std::string func_name, Vec<ASR::ttype_t*> args,
886-
const Location &loc) {
887-
for(auto &t: overload_defs[func_name]) {
888-
SymbolTable *symtab = current_scope;
889-
while (symtab!= nullptr && symtab->scope.find(t) == symtab->scope.end()) {
890-
symtab = symtab->parent;
891-
}
892-
LFORTRAN_ASSERT(symtab != nullptr);
893-
ASR::symbol_t *st = symtab->scope[t];
894-
bool ok = select_func_subrout(st, args, loc,
895-
[&](const std::string &msg, const Location &l) { throw SemanticError(msg, l); });
896-
if (ok) {
897-
return st;
898-
}
899-
}
900-
return nullptr;
901-
}
902-
903900
void visit_FunctionDef(const AST::FunctionDef_t &x) {
904901
SymbolTable *old_scope = current_scope;
905-
ASR::symbol_t *t = nullptr;
906-
if (overload_defs.find(x.m_name) != overload_defs.end()) {
907-
Vec<ASR::ttype_t *> args;
908-
args.reserve(al, x.m_args.n_args);
909-
for (size_t i=0; i<x.m_args.n_args; i++) {
910-
ASR::ttype_t *arg_type = ast_expr_to_asr_type(x.base.base.loc,
911-
*x.m_args.m_args[i].m_annotation);
912-
args.push_back(al, arg_type);
913-
}
914-
t = overloaddef_find_helper(x.m_name, args, x.base.base.loc);
915-
} else {
916-
t = current_scope->scope[x.m_name];
917-
}
902+
ASR::symbol_t *t = t = current_scope->scope[x.m_name];
918903
if (ASR::is_a<ASR::Subroutine_t>(*t)) {
919904
handle_fn(x, *ASR::down_cast<ASR::Subroutine_t>(t));
920905
} else if (ASR::is_a<ASR::Function_t>(*t)) {
921906
ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(t);
922907
handle_fn(x, *f);
908+
} else if (ASR::is_a<ASR::GenericProcedure_t>(*t)) {
909+
ASR::symbol_t *s = ast_overload[(int64_t)&x];
910+
if (ASR::is_a<ASR::Subroutine_t>(*s)) {
911+
handle_fn(x, *ASR::down_cast<ASR::Subroutine_t>(s));
912+
} else if (ASR::is_a<ASR::Function_t>(*s)) {
913+
ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(s);
914+
handle_fn(x, *f);
915+
} else {
916+
LFORTRAN_ASSERT(false);
917+
}
923918
} else {
924919
LFORTRAN_ASSERT(false);
925920
}
@@ -2192,15 +2187,13 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
21922187
x.base.base.loc);
21932188
}
21942189

2195-
ASR::symbol_t *s = current_scope->resolve_symbol(call_name);
2196-
2197-
if (!s && overload_defs.find(call_name)!=overload_defs.end()) {
2198-
Vec<ASR::ttype_t*> args_type;
2199-
args_type.reserve(al, x.n_args);
2200-
for(size_t i=0; i<x.n_args; i++) {
2201-
args_type.push_back(al, ASRUtils::expr_type(args[i]));
2202-
}
2203-
s = overloaddef_find_helper(call_name, args_type, x.base.base.loc);
2190+
ASR::symbol_t *s = current_scope->resolve_symbol(call_name), *s_generic = nullptr;
2191+
if (s->type == ASR::symbolType::GenericProcedure){
2192+
ASR::GenericProcedure_t *p = ASR::down_cast<ASR::GenericProcedure_t>(s);
2193+
int idx = select_generic_procedure(args, *p, x.base.base.loc);
2194+
// Create ExternalSymbol for procedures in different modules.
2195+
s_generic = s;
2196+
s = p->m_procs[idx];
22042197
}
22052198

22062199
if (!s) {
@@ -2347,6 +2340,27 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
23472340
x.base.base.loc);
23482341
}
23492342
}
2343+
int select_generic_procedure(const Vec<ASR::expr_t*> &args,
2344+
const ASR::GenericProcedure_t &p, Location loc) {
2345+
for (size_t i=0; i < p.n_procs; i++) {
2346+
2347+
if( ASR::is_a<ASR::ClassProcedure_t>(*p.m_procs[i]) ) {
2348+
ASR::ClassProcedure_t *clss_fn
2349+
= ASR::down_cast<ASR::ClassProcedure_t>(p.m_procs[i]);
2350+
const ASR::symbol_t *proc = ASRUtils::symbol_get_past_external(clss_fn->m_proc);
2351+
if( select_func_subrout(proc, args, loc,
2352+
[&](const std::string &msg, const Location &loc) { throw SemanticError(msg, loc); })
2353+
){
2354+
return i;
2355+
}
2356+
} else {
2357+
if( select_func_subrout(p.m_procs[i], args, loc, [&](const std::string &msg, const Location &loc) { throw SemanticError(msg, loc); }) ) {
2358+
return i;
2359+
}
2360+
}
2361+
}
2362+
throw SemanticError("Arguments do not match for any generic procedure", loc);
2363+
}
23502364

23512365
void visit_ImportFrom(const AST::ImportFrom_t &/*x*/) {
23522366
// Handled by SymbolTableVisitor already
@@ -2361,7 +2375,6 @@ Result<ASR::TranslationUnit_t*> body_visitor(Allocator &al,
23612375
{
23622376
BodyVisitor b(al, unit, diagnostics, main_module);
23632377
try {
2364-
b.overload_defs = overload_definitons;
23652378
b.visit_Module(ast);
23662379
} catch (const SemanticError &e) {
23672380
Error error;

0 commit comments

Comments
 (0)