Skip to content
Merged
4 changes: 2 additions & 2 deletions integration_tests/run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,17 @@
"test_math1.py",
"test_math_02.py",
"test_c_interop_01.py",
"test_generics_01.py",
]

# CPython tests only
test_cpython = [
"test_generics_01.py",
"test_builtin_bin.py",
"test_builtin_hex.py",
"test_builtin_oct.py"
]

CUR_DIR = ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__)))
CUR_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__)))

def main():
if not os.path.exists(os.path.join(CUR_DIR, 'tmp')):
Expand Down
13 changes: 8 additions & 5 deletions integration_tests/test_generics_01.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,11 @@ def test(a: bool) -> i32:
return -10


assert foo(2) == 4
assert foo(2, 10) == 20
assert foo("hello") == "lpython-hello"
assert test(10) == 20
assert test(False) == -test(True) and test(True) == 10
def check():
assert foo(2) == 4
assert foo(2, 10) == 20
assert foo("hello") == "lpython-hello"
assert test(10) == 20
assert test(False) == -test(True) and test(True) == 10

check()
120 changes: 95 additions & 25 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ ASR::Module_t* load_module(Allocator &al, SymbolTable *symtab,
return mod2;
}


template <class Derived>
class CommonVisitor : public AST::BaseVisitor<Derived> {
public:
Expand All @@ -156,10 +157,13 @@ class CommonVisitor : public AST::BaseVisitor<Derived> {
// The main module is stored directly in TranslationUnit, other modules are Modules
bool main_module;
PythonIntrinsicProcedures intrinsic_procedures;
std::map<int, ASR::symbol_t*> &ast_overload;

CommonVisitor(Allocator &al, SymbolTable *symbol_table,
diag::Diagnostics &diagnostics, bool main_module)
: diag{diagnostics}, al{al}, current_scope{symbol_table}, main_module{main_module} {
diag::Diagnostics &diagnostics, bool main_module,
std::map<int, ASR::symbol_t*> &ast_overload)
: diag{diagnostics}, al{al}, current_scope{symbol_table}, main_module{main_module},
ast_overload{ast_overload} {
current_module_dependencies.reserve(al, 4);
}

Expand Down Expand Up @@ -445,7 +449,7 @@ ASR::symbol_t* import_from_module(Allocator &al, ASR::Module_t *m, SymbolTable *
throw SemanticError("Only Subroutines, Functions and Variables are currently supported in 'import'",
loc);
}
// should not reach here
LFORTRAN_ASSERT(false);
return nullptr;
}

Expand All @@ -469,11 +473,13 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
std::map<SymbolTable*, ASR::accessType> assgn;
ASR::symbol_t *current_module_sym;
std::vector<std::string> excluded_from_symtab;
std::map<std::string, Vec<ASR::symbol_t* >> overload_defs;


SymbolTableVisitor(Allocator &al, SymbolTable *symbol_table,
diag::Diagnostics &diagnostics, bool main_module)
: CommonVisitor(al, symbol_table, diagnostics, main_module), is_derived_type{false} {}
diag::Diagnostics &diagnostics, bool main_module,
std::map<int, ASR::symbol_t*> &ast_overload)
: CommonVisitor(al, symbol_table, diagnostics, main_module, ast_overload), is_derived_type{false} {}


ASR::symbol_t* resolve_symbol(const Location &loc, const std::string &sub_name) {
Expand Down Expand Up @@ -522,7 +528,9 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
for (size_t i=0; i<x.n_body; i++) {
visit_stmt(*x.m_body[i]);
}

if (!overload_defs.empty()) {
create_GenericProcedure(x.base.base.loc);
}
global_scope = nullptr;
tmp = tmp0;
}
Expand All @@ -534,12 +542,23 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
Vec<ASR::expr_t*> args;
args.reserve(al, x.m_args.n_args);
current_procedure_abi_type = ASR::abiType::Source;
if (x.n_decorator_list == 1) {
AST::expr_t *dec = x.m_decorator_list[0];
if (AST::is_a<AST::Name_t>(*dec)) {
std::string name = AST::down_cast<AST::Name_t>(dec)->m_id;
if (name == "ccall") {
current_procedure_abi_type = ASR::abiType::BindC;
bool overload = false;
if (x.n_decorator_list > 0) {
for(size_t i=0; i<x.n_decorator_list; i++) {
AST::expr_t *dec = x.m_decorator_list[i];
if (AST::is_a<AST::Name_t>(*dec)) {
std::string name = AST::down_cast<AST::Name_t>(dec)->m_id;
if (name == "ccall") {
current_procedure_abi_type = ASR::abiType::BindC;
} else if (name == "overload") {
overload = true;
} else {
throw SemanticError("Decorator: " + name + " is not supported",
x.base.base.loc);
}
} else {
throw SemanticError("Unsupported Decorator type",
x.base.base.loc);
}
}
}
Expand Down Expand Up @@ -578,6 +597,18 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
var)));
}
std::string sym_name = x.m_name;
if (overload) {
std::string overload_number;
if (overload_defs.find(sym_name) == overload_defs.end()){
overload_number = "0";
Vec<ASR::symbol_t *> v;
v.reserve(al, 1);
overload_defs[sym_name] = v;
} else {
overload_number = std::to_string(overload_defs[sym_name].size());
}
sym_name = "__lpython_overloaded_" + overload_number + "__" + sym_name;
}
if (parent_scope->scope.find(sym_name) != parent_scope->scope.end()) {
throw SemanticError("Subroutine already defined", tmp->loc);
}
Expand Down Expand Up @@ -631,8 +662,23 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
s_access, deftype, bindc_name,
is_pure, is_module);
}
parent_scope->scope[sym_name] = ASR::down_cast<ASR::symbol_t>(tmp);
ASR::symbol_t * t = ASR::down_cast<ASR::symbol_t>(tmp);
parent_scope->scope[sym_name] = t;
current_scope = parent_scope;
if (overload) {
overload_defs[x.m_name].push_back(al, t);
ast_overload[(int64_t)&x] = t;
}
}

void create_GenericProcedure(const Location &loc) {
for(auto &p: overload_defs) {
std::string def_name = p.first;
tmp = ASR::make_GenericProcedure_t(al, loc, current_scope, s2c(al, def_name),
p.second.p, p.second.size(), ASR::accessType::Public);
ASR::symbol_t *t = ASR::down_cast<ASR::symbol_t>(tmp);
current_scope->scope[def_name] = t;
}
}

void visit_ImportFrom(const AST::ImportFrom_t &x) {
Expand Down Expand Up @@ -724,9 +770,10 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
};

Result<ASR::asr_t*> symbol_table_visitor(Allocator &al, const AST::Module_t &ast,
diag::Diagnostics &diagnostics, bool main_module)
diag::Diagnostics &diagnostics, bool main_module,
std::map<int, ASR::symbol_t*> &ast_overload)
{
SymbolTableVisitor v(al, nullptr, diagnostics, main_module);
SymbolTableVisitor v(al, nullptr, diagnostics, main_module, ast_overload);
try {
v.visit_Module(ast);
} catch (const SemanticError &e) {
Expand All @@ -748,8 +795,9 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
ASR::asr_t *asr;
Vec<ASR::stmt_t*> *current_body;

BodyVisitor(Allocator &al, ASR::asr_t *unit, diag::Diagnostics &diagnostics, bool main_module)
: CommonVisitor(al, nullptr, diagnostics, main_module), asr{unit} {}
BodyVisitor(Allocator &al, ASR::asr_t *unit, diag::Diagnostics &diagnostics,
bool main_module, std::map<int, ASR::symbol_t*> &ast_overload)
: CommonVisitor(al, nullptr, diagnostics, main_module, ast_overload), asr{unit} {}

// Transforms statements to a list of ASR statements
// In addition, it also inserts the following nodes if needed:
Expand Down Expand Up @@ -817,6 +865,16 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
} else if (ASR::is_a<ASR::Function_t>(*t)) {
ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(t);
handle_fn(x, *f);
} else if (ASR::is_a<ASR::GenericProcedure_t>(*t)) {
ASR::symbol_t *s = ast_overload[(int64_t)&x];
if (ASR::is_a<ASR::Subroutine_t>(*s)) {
handle_fn(x, *ASR::down_cast<ASR::Subroutine_t>(s));
} else if (ASR::is_a<ASR::Function_t>(*s)) {
ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(s);
handle_fn(x, *f);
} else {
LFORTRAN_ASSERT(false);
}
} else {
LFORTRAN_ASSERT(false);
}
Expand Down Expand Up @@ -2108,8 +2166,15 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
x.base.base.loc);
}

ASR::symbol_t *s = current_scope->resolve_symbol(call_name);

ASR::symbol_t *s = current_scope->resolve_symbol(call_name), *s_generic = nullptr;
if (s!=nullptr && s->type == ASR::symbolType::GenericProcedure) {
ASR::GenericProcedure_t *p = ASR::down_cast<ASR::GenericProcedure_t>(s);
int idx = ASRUtils::select_generic_procedure(args, *p, x.base.base.loc,
[&](const std::string &msg, const Location &loc) { throw SemanticError(msg, loc); });
// Create ExternalSymbol for procedures in different modules.
s_generic = s;
s = p->m_procs[idx];
}

if (!s) {
if (intrinsic_procedures.is_intrinsic(call_name)) {
Expand Down Expand Up @@ -2246,10 +2311,10 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
value = intrinsic_procedures.comptime_eval(call_name, al, x.base.base.loc, args);
}
tmp = ASR::make_FunctionCall_t(al, x.base.base.loc, stemp,
nullptr, args.p, args.size(), nullptr, 0, a_type, value, nullptr);
s_generic, args.p, args.size(), nullptr, 0, a_type, value, nullptr);
} else if(ASR::is_a<ASR::Subroutine_t>(*s)) {
tmp = ASR::make_SubroutineCall_t(al, x.base.base.loc, stemp,
nullptr, args.p, args.size(), nullptr);
s_generic, args.p, args.size(), nullptr);
} else {
throw SemanticError("Unsupported call type for " + call_name,
x.base.base.loc);
Expand All @@ -2265,9 +2330,10 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
Result<ASR::TranslationUnit_t*> body_visitor(Allocator &al,
const AST::Module_t &ast,
diag::Diagnostics &diagnostics,
ASR::asr_t *unit, bool main_module)
ASR::asr_t *unit, bool main_module,
std::map<int, ASR::symbol_t*> &ast_overload)
{
BodyVisitor b(al, unit, diagnostics, main_module);
BodyVisitor b(al, unit, diagnostics, main_module, ast_overload);
try {
b.visit_Module(ast);
} catch (const SemanticError &e) {
Expand Down Expand Up @@ -2301,10 +2367,13 @@ std::string pickle_python(AST::ast_t &ast, bool colors, bool indent) {
Result<ASR::TranslationUnit_t*> python_ast_to_asr(Allocator &al,
AST::ast_t &ast, diag::Diagnostics &diagnostics, bool main_module)
{
std::map<int, ASR::symbol_t*> ast_overload;

AST::Module_t *ast_m = AST::down_cast2<AST::Module_t>(&ast);

ASR::asr_t *unit;
auto res = symbol_table_visitor(al, *ast_m, diagnostics, main_module);
auto res = symbol_table_visitor(al, *ast_m, diagnostics, main_module,
ast_overload);
if (res.ok) {
unit = res.result;
} else {
Expand All @@ -2313,7 +2382,8 @@ Result<ASR::TranslationUnit_t*> python_ast_to_asr(Allocator &al,
ASR::TranslationUnit_t *tu = ASR::down_cast2<ASR::TranslationUnit_t>(unit);
LFORTRAN_ASSERT(asr_verify(*tu));

auto res2 = body_visitor(al, *ast_m, diagnostics, unit, main_module);
auto res2 = body_visitor(al, *ast_m, diagnostics, unit, main_module,
ast_overload);
if (res2.ok) {
tu = res2.result;
} else {
Expand Down