Skip to content

Commit 37e2685

Browse files
authored
Merge pull request #187 from Smit-create/overload3
Add overload support in lpython
2 parents 8dc7035 + ec200f5 commit 37e2685

File tree

3 files changed

+105
-32
lines changed

3 files changed

+105
-32
lines changed

integration_tests/run_tests.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,17 @@
2929
"test_math1.py",
3030
"test_math_02.py",
3131
"test_c_interop_01.py",
32+
"test_generics_01.py",
3233
]
3334

3435
# CPython tests only
3536
test_cpython = [
36-
"test_generics_01.py",
3737
"test_builtin_bin.py",
3838
"test_builtin_hex.py",
3939
"test_builtin_oct.py"
4040
]
4141

42-
CUR_DIR = ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__)))
42+
CUR_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__)))
4343

4444
def main():
4545
if not os.path.exists(os.path.join(CUR_DIR, 'tmp')):

integration_tests/test_generics_01.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,11 @@ def test(a: bool) -> i32:
2323
return -10
2424

2525

26-
assert foo(2) == 4
27-
assert foo(2, 10) == 20
28-
assert foo("hello") == "lpython-hello"
29-
assert test(10) == 20
30-
assert test(False) == -test(True) and test(True) == 10
26+
def check():
27+
assert foo(2) == 4
28+
assert foo(2, 10) == 20
29+
assert foo("hello") == "lpython-hello"
30+
assert test(10) == 20
31+
assert test(False) == -test(True) and test(True) == 10
32+
33+
check()

src/lpython/semantics/python_ast_to_asr.cpp

Lines changed: 95 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ ASR::Module_t* load_module(Allocator &al, SymbolTable *symtab,
139139
return mod2;
140140
}
141141

142+
142143
template <class Derived>
143144
class CommonVisitor : public AST::BaseVisitor<Derived> {
144145
public:
@@ -156,10 +157,13 @@ class CommonVisitor : public AST::BaseVisitor<Derived> {
156157
// The main module is stored directly in TranslationUnit, other modules are Modules
157158
bool main_module;
158159
PythonIntrinsicProcedures intrinsic_procedures;
160+
std::map<int, ASR::symbol_t*> &ast_overload;
159161

160162
CommonVisitor(Allocator &al, SymbolTable *symbol_table,
161-
diag::Diagnostics &diagnostics, bool main_module)
162-
: diag{diagnostics}, al{al}, current_scope{symbol_table}, main_module{main_module} {
163+
diag::Diagnostics &diagnostics, bool main_module,
164+
std::map<int, ASR::symbol_t*> &ast_overload)
165+
: diag{diagnostics}, al{al}, current_scope{symbol_table}, main_module{main_module},
166+
ast_overload{ast_overload} {
163167
current_module_dependencies.reserve(al, 4);
164168
}
165169

@@ -445,7 +449,7 @@ ASR::symbol_t* import_from_module(Allocator &al, ASR::Module_t *m, SymbolTable *
445449
throw SemanticError("Only Subroutines, Functions and Variables are currently supported in 'import'",
446450
loc);
447451
}
448-
// should not reach here
452+
LFORTRAN_ASSERT(false);
449453
return nullptr;
450454
}
451455

@@ -469,11 +473,13 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
469473
std::map<SymbolTable*, ASR::accessType> assgn;
470474
ASR::symbol_t *current_module_sym;
471475
std::vector<std::string> excluded_from_symtab;
476+
std::map<std::string, Vec<ASR::symbol_t* >> overload_defs;
472477

473478

474479
SymbolTableVisitor(Allocator &al, SymbolTable *symbol_table,
475-
diag::Diagnostics &diagnostics, bool main_module)
476-
: CommonVisitor(al, symbol_table, diagnostics, main_module), is_derived_type{false} {}
480+
diag::Diagnostics &diagnostics, bool main_module,
481+
std::map<int, ASR::symbol_t*> &ast_overload)
482+
: CommonVisitor(al, symbol_table, diagnostics, main_module, ast_overload), is_derived_type{false} {}
477483

478484

479485
ASR::symbol_t* resolve_symbol(const Location &loc, const std::string &sub_name) {
@@ -522,7 +528,9 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
522528
for (size_t i=0; i<x.n_body; i++) {
523529
visit_stmt(*x.m_body[i]);
524530
}
525-
531+
if (!overload_defs.empty()) {
532+
create_GenericProcedure(x.base.base.loc);
533+
}
526534
global_scope = nullptr;
527535
tmp = tmp0;
528536
}
@@ -534,12 +542,23 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
534542
Vec<ASR::expr_t*> args;
535543
args.reserve(al, x.m_args.n_args);
536544
current_procedure_abi_type = ASR::abiType::Source;
537-
if (x.n_decorator_list == 1) {
538-
AST::expr_t *dec = x.m_decorator_list[0];
539-
if (AST::is_a<AST::Name_t>(*dec)) {
540-
std::string name = AST::down_cast<AST::Name_t>(dec)->m_id;
541-
if (name == "ccall") {
542-
current_procedure_abi_type = ASR::abiType::BindC;
545+
bool overload = false;
546+
if (x.n_decorator_list > 0) {
547+
for(size_t i=0; i<x.n_decorator_list; i++) {
548+
AST::expr_t *dec = x.m_decorator_list[i];
549+
if (AST::is_a<AST::Name_t>(*dec)) {
550+
std::string name = AST::down_cast<AST::Name_t>(dec)->m_id;
551+
if (name == "ccall") {
552+
current_procedure_abi_type = ASR::abiType::BindC;
553+
} else if (name == "overload") {
554+
overload = true;
555+
} else {
556+
throw SemanticError("Decorator: " + name + " is not supported",
557+
x.base.base.loc);
558+
}
559+
} else {
560+
throw SemanticError("Unsupported Decorator type",
561+
x.base.base.loc);
543562
}
544563
}
545564
}
@@ -578,6 +597,18 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
578597
var)));
579598
}
580599
std::string sym_name = x.m_name;
600+
if (overload) {
601+
std::string overload_number;
602+
if (overload_defs.find(sym_name) == overload_defs.end()){
603+
overload_number = "0";
604+
Vec<ASR::symbol_t *> v;
605+
v.reserve(al, 1);
606+
overload_defs[sym_name] = v;
607+
} else {
608+
overload_number = std::to_string(overload_defs[sym_name].size());
609+
}
610+
sym_name = "__lpython_overloaded_" + overload_number + "__" + sym_name;
611+
}
581612
if (parent_scope->scope.find(sym_name) != parent_scope->scope.end()) {
582613
throw SemanticError("Subroutine already defined", tmp->loc);
583614
}
@@ -631,8 +662,23 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
631662
s_access, deftype, bindc_name,
632663
is_pure, is_module);
633664
}
634-
parent_scope->scope[sym_name] = ASR::down_cast<ASR::symbol_t>(tmp);
665+
ASR::symbol_t * t = ASR::down_cast<ASR::symbol_t>(tmp);
666+
parent_scope->scope[sym_name] = t;
635667
current_scope = parent_scope;
668+
if (overload) {
669+
overload_defs[x.m_name].push_back(al, t);
670+
ast_overload[(int64_t)&x] = t;
671+
}
672+
}
673+
674+
void create_GenericProcedure(const Location &loc) {
675+
for(auto &p: overload_defs) {
676+
std::string def_name = p.first;
677+
tmp = ASR::make_GenericProcedure_t(al, loc, current_scope, s2c(al, def_name),
678+
p.second.p, p.second.size(), ASR::accessType::Public);
679+
ASR::symbol_t *t = ASR::down_cast<ASR::symbol_t>(tmp);
680+
current_scope->scope[def_name] = t;
681+
}
636682
}
637683

638684
void visit_ImportFrom(const AST::ImportFrom_t &x) {
@@ -724,9 +770,10 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
724770
};
725771

726772
Result<ASR::asr_t*> symbol_table_visitor(Allocator &al, const AST::Module_t &ast,
727-
diag::Diagnostics &diagnostics, bool main_module)
773+
diag::Diagnostics &diagnostics, bool main_module,
774+
std::map<int, ASR::symbol_t*> &ast_overload)
728775
{
729-
SymbolTableVisitor v(al, nullptr, diagnostics, main_module);
776+
SymbolTableVisitor v(al, nullptr, diagnostics, main_module, ast_overload);
730777
try {
731778
v.visit_Module(ast);
732779
} catch (const SemanticError &e) {
@@ -748,8 +795,9 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
748795
ASR::asr_t *asr;
749796
Vec<ASR::stmt_t*> *current_body;
750797

751-
BodyVisitor(Allocator &al, ASR::asr_t *unit, diag::Diagnostics &diagnostics, bool main_module)
752-
: CommonVisitor(al, nullptr, diagnostics, main_module), asr{unit} {}
798+
BodyVisitor(Allocator &al, ASR::asr_t *unit, diag::Diagnostics &diagnostics,
799+
bool main_module, std::map<int, ASR::symbol_t*> &ast_overload)
800+
: CommonVisitor(al, nullptr, diagnostics, main_module, ast_overload), asr{unit} {}
753801

754802
// Transforms statements to a list of ASR statements
755803
// In addition, it also inserts the following nodes if needed:
@@ -817,6 +865,16 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
817865
} else if (ASR::is_a<ASR::Function_t>(*t)) {
818866
ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(t);
819867
handle_fn(x, *f);
868+
} else if (ASR::is_a<ASR::GenericProcedure_t>(*t)) {
869+
ASR::symbol_t *s = ast_overload[(int64_t)&x];
870+
if (ASR::is_a<ASR::Subroutine_t>(*s)) {
871+
handle_fn(x, *ASR::down_cast<ASR::Subroutine_t>(s));
872+
} else if (ASR::is_a<ASR::Function_t>(*s)) {
873+
ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(s);
874+
handle_fn(x, *f);
875+
} else {
876+
LFORTRAN_ASSERT(false);
877+
}
820878
} else {
821879
LFORTRAN_ASSERT(false);
822880
}
@@ -2108,8 +2166,15 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
21082166
x.base.base.loc);
21092167
}
21102168

2111-
ASR::symbol_t *s = current_scope->resolve_symbol(call_name);
2112-
2169+
ASR::symbol_t *s = current_scope->resolve_symbol(call_name), *s_generic = nullptr;
2170+
if (s!=nullptr && s->type == ASR::symbolType::GenericProcedure) {
2171+
ASR::GenericProcedure_t *p = ASR::down_cast<ASR::GenericProcedure_t>(s);
2172+
int idx = ASRUtils::select_generic_procedure(args, *p, x.base.base.loc,
2173+
[&](const std::string &msg, const Location &loc) { throw SemanticError(msg, loc); });
2174+
// Create ExternalSymbol for procedures in different modules.
2175+
s_generic = s;
2176+
s = p->m_procs[idx];
2177+
}
21132178

21142179
if (!s) {
21152180
if (intrinsic_procedures.is_intrinsic(call_name)) {
@@ -2246,10 +2311,10 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
22462311
value = intrinsic_procedures.comptime_eval(call_name, al, x.base.base.loc, args);
22472312
}
22482313
tmp = ASR::make_FunctionCall_t(al, x.base.base.loc, stemp,
2249-
nullptr, args.p, args.size(), nullptr, 0, a_type, value, nullptr);
2314+
s_generic, args.p, args.size(), nullptr, 0, a_type, value, nullptr);
22502315
} else if(ASR::is_a<ASR::Subroutine_t>(*s)) {
22512316
tmp = ASR::make_SubroutineCall_t(al, x.base.base.loc, stemp,
2252-
nullptr, args.p, args.size(), nullptr);
2317+
s_generic, args.p, args.size(), nullptr);
22532318
} else {
22542319
throw SemanticError("Unsupported call type for " + call_name,
22552320
x.base.base.loc);
@@ -2265,9 +2330,10 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
22652330
Result<ASR::TranslationUnit_t*> body_visitor(Allocator &al,
22662331
const AST::Module_t &ast,
22672332
diag::Diagnostics &diagnostics,
2268-
ASR::asr_t *unit, bool main_module)
2333+
ASR::asr_t *unit, bool main_module,
2334+
std::map<int, ASR::symbol_t*> &ast_overload)
22692335
{
2270-
BodyVisitor b(al, unit, diagnostics, main_module);
2336+
BodyVisitor b(al, unit, diagnostics, main_module, ast_overload);
22712337
try {
22722338
b.visit_Module(ast);
22732339
} catch (const SemanticError &e) {
@@ -2301,10 +2367,13 @@ std::string pickle_python(AST::ast_t &ast, bool colors, bool indent) {
23012367
Result<ASR::TranslationUnit_t*> python_ast_to_asr(Allocator &al,
23022368
AST::ast_t &ast, diag::Diagnostics &diagnostics, bool main_module)
23032369
{
2370+
std::map<int, ASR::symbol_t*> ast_overload;
2371+
23042372
AST::Module_t *ast_m = AST::down_cast2<AST::Module_t>(&ast);
23052373

23062374
ASR::asr_t *unit;
2307-
auto res = symbol_table_visitor(al, *ast_m, diagnostics, main_module);
2375+
auto res = symbol_table_visitor(al, *ast_m, diagnostics, main_module,
2376+
ast_overload);
23082377
if (res.ok) {
23092378
unit = res.result;
23102379
} else {
@@ -2313,7 +2382,8 @@ Result<ASR::TranslationUnit_t*> python_ast_to_asr(Allocator &al,
23132382
ASR::TranslationUnit_t *tu = ASR::down_cast2<ASR::TranslationUnit_t>(unit);
23142383
LFORTRAN_ASSERT(asr_verify(*tu));
23152384

2316-
auto res2 = body_visitor(al, *ast_m, diagnostics, unit, main_module);
2385+
auto res2 = body_visitor(al, *ast_m, diagnostics, unit, main_module,
2386+
ast_overload);
23172387
if (res2.ok) {
23182388
tu = res2.result;
23192389
} else {

0 commit comments

Comments
 (0)