Skip to content

Add overload support in lpython #187

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Mar 9, 2022
Merged
4 changes: 2 additions & 2 deletions integration_tests/run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,17 @@
"test_math1.py",
"test_math_02.py",
"test_c_interop_01.py",
"test_generics_01.py",
]

# CPython tests only
test_cpython = [
"test_generics_01.py",
"test_builtin_bin.py",
"test_builtin_hex.py",
"test_builtin_oct.py"
]

CUR_DIR = ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__)))
CUR_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__)))

def main():
if not os.path.exists(os.path.join(CUR_DIR, 'tmp')):
Expand Down
13 changes: 8 additions & 5 deletions integration_tests/test_generics_01.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,11 @@ def test(a: bool) -> i32:
return -10


assert foo(2) == 4
assert foo(2, 10) == 20
assert foo("hello") == "lpython-hello"
assert test(10) == 20
assert test(False) == -test(True) and test(True) == 10
def check():
assert foo(2) == 4
assert foo(2, 10) == 20
assert foo("hello") == "lpython-hello"
assert test(10) == 20
assert test(False) == -test(True) and test(True) == 10

check()
120 changes: 95 additions & 25 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ ASR::Module_t* load_module(Allocator &al, SymbolTable *symtab,
return mod2;
}


template <class Derived>
class CommonVisitor : public AST::BaseVisitor<Derived> {
public:
Expand All @@ -156,10 +157,13 @@ class CommonVisitor : public AST::BaseVisitor<Derived> {
// The main module is stored directly in TranslationUnit, other modules are Modules
bool main_module;
PythonIntrinsicProcedures intrinsic_procedures;
std::map<int, ASR::symbol_t*> &ast_overload;

CommonVisitor(Allocator &al, SymbolTable *symbol_table,
diag::Diagnostics &diagnostics, bool main_module)
: diag{diagnostics}, al{al}, current_scope{symbol_table}, main_module{main_module} {
diag::Diagnostics &diagnostics, bool main_module,
std::map<int, ASR::symbol_t*> &ast_overload)
: diag{diagnostics}, al{al}, current_scope{symbol_table}, main_module{main_module},
ast_overload{ast_overload} {
current_module_dependencies.reserve(al, 4);
}

Expand Down Expand Up @@ -445,7 +449,7 @@ ASR::symbol_t* import_from_module(Allocator &al, ASR::Module_t *m, SymbolTable *
throw SemanticError("Only Subroutines, Functions and Variables are currently supported in 'import'",
loc);
}
// should not reach here
LFORTRAN_ASSERT(false);
return nullptr;
}

Expand All @@ -469,11 +473,13 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
std::map<SymbolTable*, ASR::accessType> assgn;
ASR::symbol_t *current_module_sym;
std::vector<std::string> excluded_from_symtab;
std::map<std::string, Vec<ASR::symbol_t* >> overload_defs;


SymbolTableVisitor(Allocator &al, SymbolTable *symbol_table,
diag::Diagnostics &diagnostics, bool main_module)
: CommonVisitor(al, symbol_table, diagnostics, main_module), is_derived_type{false} {}
diag::Diagnostics &diagnostics, bool main_module,
std::map<int, ASR::symbol_t*> &ast_overload)
: CommonVisitor(al, symbol_table, diagnostics, main_module, ast_overload), is_derived_type{false} {}


ASR::symbol_t* resolve_symbol(const Location &loc, const std::string &sub_name) {
Expand Down Expand Up @@ -522,7 +528,9 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
for (size_t i=0; i<x.n_body; i++) {
visit_stmt(*x.m_body[i]);
}

if (!overload_defs.empty()) {
create_GenericProcedure(x.base.base.loc);
}
global_scope = nullptr;
tmp = tmp0;
}
Expand All @@ -534,12 +542,23 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
Vec<ASR::expr_t*> args;
args.reserve(al, x.m_args.n_args);
current_procedure_abi_type = ASR::abiType::Source;
if (x.n_decorator_list == 1) {
AST::expr_t *dec = x.m_decorator_list[0];
if (AST::is_a<AST::Name_t>(*dec)) {
std::string name = AST::down_cast<AST::Name_t>(dec)->m_id;
if (name == "ccall") {
current_procedure_abi_type = ASR::abiType::BindC;
bool overload = false;
if (x.n_decorator_list > 0) {
for(size_t i=0; i<x.n_decorator_list; i++) {
AST::expr_t *dec = x.m_decorator_list[i];
if (AST::is_a<AST::Name_t>(*dec)) {
std::string name = AST::down_cast<AST::Name_t>(dec)->m_id;
if (name == "ccall") {
current_procedure_abi_type = ASR::abiType::BindC;
} else if (name == "overload") {
overload = true;
} else {
throw SemanticError("Decorator: " + name + " is not supported",
x.base.base.loc);
}
} else {
throw SemanticError("Unsupported Decorator type",
x.base.base.loc);
}
}
}
Expand Down Expand Up @@ -578,6 +597,18 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
var)));
}
std::string sym_name = x.m_name;
if (overload) {
std::string overload_number;
if (overload_defs.find(sym_name) == overload_defs.end()){
overload_number = "0";
Vec<ASR::symbol_t *> v;
v.reserve(al, 1);
overload_defs[sym_name] = v;
} else {
overload_number = std::to_string(overload_defs[sym_name].size());
}
sym_name = "__lpython_overloaded_" + overload_number + "__" + sym_name;
}
if (parent_scope->scope.find(sym_name) != parent_scope->scope.end()) {
throw SemanticError("Subroutine already defined", tmp->loc);
}
Expand Down Expand Up @@ -631,8 +662,23 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
s_access, deftype, bindc_name,
is_pure, is_module);
}
parent_scope->scope[sym_name] = ASR::down_cast<ASR::symbol_t>(tmp);
ASR::symbol_t * t = ASR::down_cast<ASR::symbol_t>(tmp);
parent_scope->scope[sym_name] = t;
current_scope = parent_scope;
if (overload) {
overload_defs[x.m_name].push_back(al, t);
ast_overload[(int64_t)&x] = t;
}
}

void create_GenericProcedure(const Location &loc) {
for(auto &p: overload_defs) {
std::string def_name = p.first;
tmp = ASR::make_GenericProcedure_t(al, loc, current_scope, s2c(al, def_name),
p.second.p, p.second.size(), ASR::accessType::Public);
ASR::symbol_t *t = ASR::down_cast<ASR::symbol_t>(tmp);
current_scope->scope[def_name] = t;
}
}

void visit_ImportFrom(const AST::ImportFrom_t &x) {
Expand Down Expand Up @@ -724,9 +770,10 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
};

Result<ASR::asr_t*> symbol_table_visitor(Allocator &al, const AST::Module_t &ast,
diag::Diagnostics &diagnostics, bool main_module)
diag::Diagnostics &diagnostics, bool main_module,
std::map<int, ASR::symbol_t*> &ast_overload)
{
SymbolTableVisitor v(al, nullptr, diagnostics, main_module);
SymbolTableVisitor v(al, nullptr, diagnostics, main_module, ast_overload);
try {
v.visit_Module(ast);
} catch (const SemanticError &e) {
Expand All @@ -748,8 +795,9 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
ASR::asr_t *asr;
Vec<ASR::stmt_t*> *current_body;

BodyVisitor(Allocator &al, ASR::asr_t *unit, diag::Diagnostics &diagnostics, bool main_module)
: CommonVisitor(al, nullptr, diagnostics, main_module), asr{unit} {}
BodyVisitor(Allocator &al, ASR::asr_t *unit, diag::Diagnostics &diagnostics,
bool main_module, std::map<int, ASR::symbol_t*> &ast_overload)
: CommonVisitor(al, nullptr, diagnostics, main_module, ast_overload), asr{unit} {}

// Transforms statements to a list of ASR statements
// In addition, it also inserts the following nodes if needed:
Expand Down Expand Up @@ -817,6 +865,16 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
} else if (ASR::is_a<ASR::Function_t>(*t)) {
ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(t);
handle_fn(x, *f);
} else if (ASR::is_a<ASR::GenericProcedure_t>(*t)) {
ASR::symbol_t *s = ast_overload[(int64_t)&x];
if (ASR::is_a<ASR::Subroutine_t>(*s)) {
handle_fn(x, *ASR::down_cast<ASR::Subroutine_t>(s));
} else if (ASR::is_a<ASR::Function_t>(*s)) {
ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(s);
handle_fn(x, *f);
} else {
LFORTRAN_ASSERT(false);
}
} else {
LFORTRAN_ASSERT(false);
}
Expand Down Expand Up @@ -2108,8 +2166,15 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
x.base.base.loc);
}

ASR::symbol_t *s = current_scope->resolve_symbol(call_name);

ASR::symbol_t *s = current_scope->resolve_symbol(call_name), *s_generic = nullptr;
if (s!=nullptr && s->type == ASR::symbolType::GenericProcedure) {
ASR::GenericProcedure_t *p = ASR::down_cast<ASR::GenericProcedure_t>(s);
int idx = ASRUtils::select_generic_procedure(args, *p, x.base.base.loc,
[&](const std::string &msg, const Location &loc) { throw SemanticError(msg, loc); });
// Create ExternalSymbol for procedures in different modules.
s_generic = s;
s = p->m_procs[idx];
}

if (!s) {
if (intrinsic_procedures.is_intrinsic(call_name)) {
Expand Down Expand Up @@ -2246,10 +2311,10 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
value = intrinsic_procedures.comptime_eval(call_name, al, x.base.base.loc, args);
}
tmp = ASR::make_FunctionCall_t(al, x.base.base.loc, stemp,
nullptr, args.p, args.size(), nullptr, 0, a_type, value, nullptr);
s_generic, args.p, args.size(), nullptr, 0, a_type, value, nullptr);
} else if(ASR::is_a<ASR::Subroutine_t>(*s)) {
tmp = ASR::make_SubroutineCall_t(al, x.base.base.loc, stemp,
nullptr, args.p, args.size(), nullptr);
s_generic, args.p, args.size(), nullptr);
} else {
throw SemanticError("Unsupported call type for " + call_name,
x.base.base.loc);
Expand All @@ -2265,9 +2330,10 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
Result<ASR::TranslationUnit_t*> body_visitor(Allocator &al,
const AST::Module_t &ast,
diag::Diagnostics &diagnostics,
ASR::asr_t *unit, bool main_module)
ASR::asr_t *unit, bool main_module,
std::map<int, ASR::symbol_t*> &ast_overload)
{
BodyVisitor b(al, unit, diagnostics, main_module);
BodyVisitor b(al, unit, diagnostics, main_module, ast_overload);
try {
b.visit_Module(ast);
} catch (const SemanticError &e) {
Expand Down Expand Up @@ -2301,10 +2367,13 @@ std::string pickle_python(AST::ast_t &ast, bool colors, bool indent) {
Result<ASR::TranslationUnit_t*> python_ast_to_asr(Allocator &al,
AST::ast_t &ast, diag::Diagnostics &diagnostics, bool main_module)
{
std::map<int, ASR::symbol_t*> ast_overload;

AST::Module_t *ast_m = AST::down_cast2<AST::Module_t>(&ast);

ASR::asr_t *unit;
auto res = symbol_table_visitor(al, *ast_m, diagnostics, main_module);
auto res = symbol_table_visitor(al, *ast_m, diagnostics, main_module,
ast_overload);
if (res.ok) {
unit = res.result;
} else {
Expand All @@ -2313,7 +2382,8 @@ Result<ASR::TranslationUnit_t*> python_ast_to_asr(Allocator &al,
ASR::TranslationUnit_t *tu = ASR::down_cast2<ASR::TranslationUnit_t>(unit);
LFORTRAN_ASSERT(asr_verify(*tu));

auto res2 = body_visitor(al, *ast_m, diagnostics, unit, main_module);
auto res2 = body_visitor(al, *ast_m, diagnostics, unit, main_module,
ast_overload);
if (res2.ok) {
tu = res2.result;
} else {
Expand Down