diff --git a/integration_tests/run_tests.py b/integration_tests/run_tests.py index d144ea8d5a..1bea39b346 100755 --- a/integration_tests/run_tests.py +++ b/integration_tests/run_tests.py @@ -29,17 +29,17 @@ "test_math1.py", "test_math_02.py", "test_c_interop_01.py", + "test_generics_01.py", ] # CPython tests only test_cpython = [ - "test_generics_01.py", "test_builtin_bin.py", "test_builtin_hex.py", "test_builtin_oct.py" ] -CUR_DIR = ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__))) +CUR_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__))) def main(): if not os.path.exists(os.path.join(CUR_DIR, 'tmp')): diff --git a/integration_tests/test_generics_01.py b/integration_tests/test_generics_01.py index 5419ad8a5b..442c63e24a 100644 --- a/integration_tests/test_generics_01.py +++ b/integration_tests/test_generics_01.py @@ -23,8 +23,11 @@ def test(a: bool) -> i32: return -10 -assert foo(2) == 4 -assert foo(2, 10) == 20 -assert foo("hello") == "lpython-hello" -assert test(10) == 20 -assert test(False) == -test(True) and test(True) == 10 +def check(): + assert foo(2) == 4 + assert foo(2, 10) == 20 + assert foo("hello") == "lpython-hello" + assert test(10) == 20 + assert test(False) == -test(True) and test(True) == 10 + +check() diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 3dd332ad1e..a3a3ff15b8 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -139,6 +139,7 @@ ASR::Module_t* load_module(Allocator &al, SymbolTable *symtab, return mod2; } + template class CommonVisitor : public AST::BaseVisitor { public: @@ -156,10 +157,13 @@ class CommonVisitor : public AST::BaseVisitor { // The main module is stored directly in TranslationUnit, other modules are Modules bool main_module; PythonIntrinsicProcedures intrinsic_procedures; + std::map &ast_overload; CommonVisitor(Allocator &al, SymbolTable *symbol_table, - diag::Diagnostics &diagnostics, bool main_module) - : diag{diagnostics}, al{al}, current_scope{symbol_table}, main_module{main_module} { + diag::Diagnostics &diagnostics, bool main_module, + std::map &ast_overload) + : diag{diagnostics}, al{al}, current_scope{symbol_table}, main_module{main_module}, + ast_overload{ast_overload} { current_module_dependencies.reserve(al, 4); } @@ -445,7 +449,7 @@ ASR::symbol_t* import_from_module(Allocator &al, ASR::Module_t *m, SymbolTable * throw SemanticError("Only Subroutines, Functions and Variables are currently supported in 'import'", loc); } - // should not reach here + LFORTRAN_ASSERT(false); return nullptr; } @@ -469,11 +473,13 @@ class SymbolTableVisitor : public CommonVisitor { std::map assgn; ASR::symbol_t *current_module_sym; std::vector excluded_from_symtab; + std::map> overload_defs; SymbolTableVisitor(Allocator &al, SymbolTable *symbol_table, - diag::Diagnostics &diagnostics, bool main_module) - : CommonVisitor(al, symbol_table, diagnostics, main_module), is_derived_type{false} {} + diag::Diagnostics &diagnostics, bool main_module, + std::map &ast_overload) + : CommonVisitor(al, symbol_table, diagnostics, main_module, ast_overload), is_derived_type{false} {} ASR::symbol_t* resolve_symbol(const Location &loc, const std::string &sub_name) { @@ -522,7 +528,9 @@ class SymbolTableVisitor : public CommonVisitor { for (size_t i=0; i { Vec args; args.reserve(al, x.m_args.n_args); current_procedure_abi_type = ASR::abiType::Source; - if (x.n_decorator_list == 1) { - AST::expr_t *dec = x.m_decorator_list[0]; - if (AST::is_a(*dec)) { - std::string name = AST::down_cast(dec)->m_id; - if (name == "ccall") { - current_procedure_abi_type = ASR::abiType::BindC; + bool overload = false; + if (x.n_decorator_list > 0) { + for(size_t i=0; i(*dec)) { + std::string name = AST::down_cast(dec)->m_id; + if (name == "ccall") { + current_procedure_abi_type = ASR::abiType::BindC; + } else if (name == "overload") { + overload = true; + } else { + throw SemanticError("Decorator: " + name + " is not supported", + x.base.base.loc); + } + } else { + throw SemanticError("Unsupported Decorator type", + x.base.base.loc); } } } @@ -578,6 +597,18 @@ class SymbolTableVisitor : public CommonVisitor { var))); } std::string sym_name = x.m_name; + if (overload) { + std::string overload_number; + if (overload_defs.find(sym_name) == overload_defs.end()){ + overload_number = "0"; + Vec v; + v.reserve(al, 1); + overload_defs[sym_name] = v; + } else { + overload_number = std::to_string(overload_defs[sym_name].size()); + } + sym_name = "__lpython_overloaded_" + overload_number + "__" + sym_name; + } if (parent_scope->scope.find(sym_name) != parent_scope->scope.end()) { throw SemanticError("Subroutine already defined", tmp->loc); } @@ -631,8 +662,23 @@ class SymbolTableVisitor : public CommonVisitor { s_access, deftype, bindc_name, is_pure, is_module); } - parent_scope->scope[sym_name] = ASR::down_cast(tmp); + ASR::symbol_t * t = ASR::down_cast(tmp); + parent_scope->scope[sym_name] = t; current_scope = parent_scope; + if (overload) { + overload_defs[x.m_name].push_back(al, t); + ast_overload[(int64_t)&x] = t; + } + } + + void create_GenericProcedure(const Location &loc) { + for(auto &p: overload_defs) { + std::string def_name = p.first; + tmp = ASR::make_GenericProcedure_t(al, loc, current_scope, s2c(al, def_name), + p.second.p, p.second.size(), ASR::accessType::Public); + ASR::symbol_t *t = ASR::down_cast(tmp); + current_scope->scope[def_name] = t; + } } void visit_ImportFrom(const AST::ImportFrom_t &x) { @@ -724,9 +770,10 @@ class SymbolTableVisitor : public CommonVisitor { }; Result symbol_table_visitor(Allocator &al, const AST::Module_t &ast, - diag::Diagnostics &diagnostics, bool main_module) + diag::Diagnostics &diagnostics, bool main_module, + std::map &ast_overload) { - SymbolTableVisitor v(al, nullptr, diagnostics, main_module); + SymbolTableVisitor v(al, nullptr, diagnostics, main_module, ast_overload); try { v.visit_Module(ast); } catch (const SemanticError &e) { @@ -748,8 +795,9 @@ class BodyVisitor : public CommonVisitor { ASR::asr_t *asr; Vec *current_body; - BodyVisitor(Allocator &al, ASR::asr_t *unit, diag::Diagnostics &diagnostics, bool main_module) - : CommonVisitor(al, nullptr, diagnostics, main_module), asr{unit} {} + BodyVisitor(Allocator &al, ASR::asr_t *unit, diag::Diagnostics &diagnostics, + bool main_module, std::map &ast_overload) + : CommonVisitor(al, nullptr, diagnostics, main_module, ast_overload), asr{unit} {} // Transforms statements to a list of ASR statements // In addition, it also inserts the following nodes if needed: @@ -817,6 +865,16 @@ class BodyVisitor : public CommonVisitor { } else if (ASR::is_a(*t)) { ASR::Function_t *f = ASR::down_cast(t); handle_fn(x, *f); + } else if (ASR::is_a(*t)) { + ASR::symbol_t *s = ast_overload[(int64_t)&x]; + if (ASR::is_a(*s)) { + handle_fn(x, *ASR::down_cast(s)); + } else if (ASR::is_a(*s)) { + ASR::Function_t *f = ASR::down_cast(s); + handle_fn(x, *f); + } else { + LFORTRAN_ASSERT(false); + } } else { LFORTRAN_ASSERT(false); } @@ -2108,8 +2166,15 @@ class BodyVisitor : public CommonVisitor { x.base.base.loc); } - ASR::symbol_t *s = current_scope->resolve_symbol(call_name); - + ASR::symbol_t *s = current_scope->resolve_symbol(call_name), *s_generic = nullptr; + if (s!=nullptr && s->type == ASR::symbolType::GenericProcedure) { + ASR::GenericProcedure_t *p = ASR::down_cast(s); + int idx = ASRUtils::select_generic_procedure(args, *p, x.base.base.loc, + [&](const std::string &msg, const Location &loc) { throw SemanticError(msg, loc); }); + // Create ExternalSymbol for procedures in different modules. + s_generic = s; + s = p->m_procs[idx]; + } if (!s) { if (intrinsic_procedures.is_intrinsic(call_name)) { @@ -2246,10 +2311,10 @@ class BodyVisitor : public CommonVisitor { value = intrinsic_procedures.comptime_eval(call_name, al, x.base.base.loc, args); } tmp = ASR::make_FunctionCall_t(al, x.base.base.loc, stemp, - nullptr, args.p, args.size(), nullptr, 0, a_type, value, nullptr); + s_generic, args.p, args.size(), nullptr, 0, a_type, value, nullptr); } else if(ASR::is_a(*s)) { tmp = ASR::make_SubroutineCall_t(al, x.base.base.loc, stemp, - nullptr, args.p, args.size(), nullptr); + s_generic, args.p, args.size(), nullptr); } else { throw SemanticError("Unsupported call type for " + call_name, x.base.base.loc); @@ -2265,9 +2330,10 @@ class BodyVisitor : public CommonVisitor { Result body_visitor(Allocator &al, const AST::Module_t &ast, diag::Diagnostics &diagnostics, - ASR::asr_t *unit, bool main_module) + ASR::asr_t *unit, bool main_module, + std::map &ast_overload) { - BodyVisitor b(al, unit, diagnostics, main_module); + BodyVisitor b(al, unit, diagnostics, main_module, ast_overload); try { b.visit_Module(ast); } catch (const SemanticError &e) { @@ -2301,10 +2367,13 @@ std::string pickle_python(AST::ast_t &ast, bool colors, bool indent) { Result python_ast_to_asr(Allocator &al, AST::ast_t &ast, diag::Diagnostics &diagnostics, bool main_module) { + std::map ast_overload; + AST::Module_t *ast_m = AST::down_cast2(&ast); ASR::asr_t *unit; - auto res = symbol_table_visitor(al, *ast_m, diagnostics, main_module); + auto res = symbol_table_visitor(al, *ast_m, diagnostics, main_module, + ast_overload); if (res.ok) { unit = res.result; } else { @@ -2313,7 +2382,8 @@ Result python_ast_to_asr(Allocator &al, ASR::TranslationUnit_t *tu = ASR::down_cast2(unit); LFORTRAN_ASSERT(asr_verify(*tu)); - auto res2 = body_visitor(al, *ast_m, diagnostics, unit, main_module); + auto res2 = body_visitor(al, *ast_m, diagnostics, unit, main_module, + ast_overload); if (res2.ok) { tu = res2.result; } else {