Skip to content

Adding type parameters to functions without return values (cleaned history) #989

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -222,5 +222,6 @@ RUN(NAME test_str_comparison LABELS cpython llvm)
RUN(NAME test_bit_length LABELS cpython llvm)

RUN(NAME generics_01 LABELS cpython llvm)
RUN(NAME generics_02 LABELS cpython llvm)
RUN(NAME generics_array_01 LABELS llvm)
RUN(NAME test_statistics LABELS cpython llvm)
13 changes: 13 additions & 0 deletions integration_tests/generics_02.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from ltypes import TypeVar

T = TypeVar('T')

def swap(x: T, y: T):
temp: T
temp = x
x = y
y = temp
print(x)
print(y)

swap(1,2)
29 changes: 29 additions & 0 deletions src/libasr/asr_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1169,6 +1169,22 @@ static inline ASR::ttype_t* duplicate_type_without_dims(Allocator& al, const ASR
return ASRUtils::TYPE(ASR::make_Integer_t(al, t->base.loc,
tnew->m_kind, nullptr, 0));
}
case ASR::ttypeType::Real: {
ASR::Real_t* tnew = ASR::down_cast<ASR::Real_t>(t);
return ASRUtils::TYPE(ASR::make_Real_t(al, t->base.loc,
tnew->m_kind, nullptr, 0));
}
case ASR::ttypeType::Character: {
ASR::Character_t* tnew = ASR::down_cast<ASR::Character_t>(t);
return ASRUtils::TYPE(ASR::make_Character_t(al, t->base.loc,
tnew->m_kind, tnew->m_len, tnew->m_len_expr,
nullptr, 0));
}
case ASR::ttypeType::TypeParameter: {
ASR::TypeParameter_t* tp = ASR::down_cast<ASR::TypeParameter_t>(t);
return ASRUtils::TYPE(ASR::make_TypeParameter_t(al, t->base.loc,
tp->m_param, nullptr, 0));
}
default : throw LCompilersException("Not implemented " + std::to_string(t->type));
}
}
Expand Down Expand Up @@ -1387,6 +1403,19 @@ static inline ASR::ttype_t* get_contained_type(ASR::ttype_t* asr_type) {
}
}

static inline ASR::ttype_t* get_type_parameter(ASR::ttype_t* t) {
switch (t->type) {
case ASR::ttypeType::TypeParameter: {
return t;
}
case ASR::ttypeType::List: {
ASR::List_t *tl = ASR::down_cast<ASR::List_t>(t);
return get_type_parameter(tl->m_type);
}
default: throw LCompilersException("Cannot get type parameter from this type.");
}
}

class ReplaceArgVisitor: public ASR::BaseExprReplacer<ReplaceArgVisitor> {

private:
Expand Down
8 changes: 6 additions & 2 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1013,7 +1013,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
// Generate function prototypes
for (auto &item : x.m_global_scope->get_scope()) {
if (is_a<ASR::Function_t>(*item.second)) {
visit_Function(*ASR::down_cast<ASR::Function_t>(item.second));
if (ASR::down_cast<ASR::Function_t>(item.second)->n_type_params == 0) {
visit_Function(*ASR::down_cast<ASR::Function_t>(item.second));
}
}
}
prototype_only = false;
Expand All @@ -1033,7 +1035,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
// Then do all the procedures
for (auto &item : x.m_global_scope->get_scope()) {
if (is_a<ASR::Function_t>(*item.second)) {
visit_symbol(*item.second);
if (ASR::down_cast<ASR::Function_t>(item.second)->n_type_params == 0) {
visit_symbol(*item.second);
}
}
}

Expand Down
47 changes: 23 additions & 24 deletions src/libasr/pass/instantiate_template.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,20 @@ class FunctionInstantiator : public ASR::BaseExprStmtDuplicator<FunctionInstanti
public:
SymbolTable *current_scope;
std::map<std::string, ASR::ttype_t*> subs;
int new_function_num;
std::string new_func_name;

FunctionInstantiator(Allocator &al, std::map<std::string, ASR::ttype_t*> subs,
SymbolTable *current_scope, int new_function_num):
SymbolTable *current_scope, std::string new_func_name):
BaseExprStmtDuplicator(al),
current_scope{current_scope},
subs{subs},
new_function_num{new_function_num}
new_func_name{new_func_name}
{}

ASR::asr_t* instantiate_Function(ASR::Function_t &x) {
SymbolTable *parent_scope = current_scope;
current_scope = al.make_new<SymbolTable>(parent_scope);

std::string func_name = x.m_name;
func_name = "__lpython_generic_" + func_name + "_" + std::to_string(new_function_num);

Vec<ASR::expr_t*> args;
args.reserve(al, x.n_args);
for (size_t i=0; i<x.n_args; i++) {
Expand Down Expand Up @@ -57,19 +54,21 @@ class FunctionInstantiator : public ASR::BaseExprStmtDuplicator<FunctionInstanti
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var)));
}

ASR::Variable_t *return_var = ASR::down_cast<ASR::Variable_t>(
(ASR::down_cast<ASR::Var_t>(x.m_return_var))->m_v);
std::string return_var_name = return_var->m_name;
ASR::ttype_t *return_param_type = ASRUtils::expr_type(x.m_return_var);
ASR::ttype_t *return_type = ASR::is_a<ASR::TypeParameter_t>(*return_param_type) ?
subs[ASR::down_cast<ASR::TypeParameter_t>(return_param_type)->m_param] : return_param_type;
ASR::asr_t *new_return_var = ASR::make_Variable_t(al, return_var->base.base.loc,
current_scope, s2c(al, return_var_name), return_var->m_intent, nullptr, nullptr,
return_var->m_storage, return_type, return_var->m_abi, return_var->m_access,
return_var->m_presence, return_var->m_value_attr);
current_scope->add_symbol(return_var_name, ASR::down_cast<ASR::symbol_t>(new_return_var));
ASR::asr_t *new_return_var_ref = ASR::make_Var_t(al, x.base.base.loc,
current_scope->get_symbol(return_var_name));
ASR::expr_t *new_return_var_ref = nullptr;
if (x.m_return_var != nullptr) {
ASR::Variable_t *return_var = ASR::down_cast<ASR::Variable_t>(
(ASR::down_cast<ASR::Var_t>(x.m_return_var))->m_v);
std::string return_var_name = return_var->m_name;
ASR::ttype_t *return_param_type = ASRUtils::expr_type(x.m_return_var);
ASR::ttype_t *return_type = substitute_type(return_param_type);
ASR::asr_t *new_return_var = ASR::make_Variable_t(al, return_var->base.base.loc,
current_scope, s2c(al, return_var_name), return_var->m_intent, nullptr, nullptr,
return_var->m_storage, return_type, return_var->m_abi, return_var->m_access,
return_var->m_presence, return_var->m_value_attr);
current_scope->add_symbol(return_var_name, ASR::down_cast<ASR::symbol_t>(new_return_var));
new_return_var_ref = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc,
current_scope->get_symbol(return_var_name)));
}

// Rebuild the symbol table
for (auto const &sym_pair: x.m_symtab->get_scope()) {
Expand Down Expand Up @@ -104,16 +103,16 @@ class FunctionInstantiator : public ASR::BaseExprStmtDuplicator<FunctionInstanti

ASR::asr_t *result = ASR::make_Function_t(
al, x.base.base.loc,
current_scope, s2c(al, func_name),
current_scope, s2c(al, new_func_name),
args.p, args.size(),
nullptr, 0,
body.p, body.size(),
ASRUtils::EXPR(new_return_var_ref),
new_return_var_ref,
func_abi, func_access, func_deftype, bindc_name,
false, false, false);

ASR::symbol_t *t = ASR::down_cast<ASR::symbol_t>(result);
parent_scope->add_symbol(func_name, t);
parent_scope->add_symbol(new_func_name, t);
current_scope = parent_scope;

return result;
Expand Down Expand Up @@ -322,8 +321,8 @@ class FunctionInstantiator : public ASR::BaseExprStmtDuplicator<FunctionInstanti
};

ASR::symbol_t* pass_instantiate_generic_function(Allocator &al, std::map<std::string, ASR::ttype_t*> subs,
SymbolTable *current_scope, int new_function_num, ASR::Function_t &func) {
FunctionInstantiator tf(al, subs, current_scope, new_function_num);
SymbolTable *current_scope, std::string new_func_name, ASR::Function_t &func) {
FunctionInstantiator tf(al, subs, current_scope, new_func_name);
ASR::asr_t *new_function = tf.instantiate_Function(func);
return ASR::down_cast<ASR::symbol_t>(new_function);
}
Expand Down
2 changes: 1 addition & 1 deletion src/libasr/pass/instantiate_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
namespace LFortran {

ASR::symbol_t* pass_instantiate_generic_function(Allocator &al, std::map<std::string, ASR::ttype_t*> subs,
SymbolTable *current_scope, int new_function_num, ASR::Function_t &func);
SymbolTable *current_scope, std::string new_func_name, ASR::Function_t &func);
}

#endif // LFORTRAN_PASS_TEMPLATE_VISITOR_H
150 changes: 73 additions & 77 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -717,10 +717,24 @@ class CommonVisitor : public AST::BaseVisitor<Derived> {
stemp = symtab->get_symbol(local_sym);
}
}
if (ASR::is_a<ASR::Function_t>(*s) &&
ASR::down_cast<ASR::Function_t>(s)->m_return_var != nullptr) {
if (ASR::is_a<ASR::Function_t>(*s)) {
ASR::Function_t *func = ASR::down_cast<ASR::Function_t>(s);
if (func->n_type_params == 0) {
if (func->n_type_params > 0) {
std::map<std::string, ASR::ttype_t*> subs;
for (size_t i=0; i<args.size(); i++) {
ASR::ttype_t *param_type = ASRUtils::expr_type(func->m_args[i]);
ASR::ttype_t *arg_type = ASRUtils::expr_type(args[i].m_value);
subs = check_type_substitution(subs, param_type, arg_type, loc);
}

ASR::symbol_t *t = get_generic_function(subs, *func);
std::string new_call_name = call_name;
if (ASR::is_a<ASR::Function_t>(*t)) {
new_call_name = (ASR::down_cast<ASR::Function_t>(t))->m_name;
}
return make_call_helper(al, t, current_scope, args, new_call_name, loc);
}
if (ASR::down_cast<ASR::Function_t>(s)->m_return_var != nullptr) {
ASR::ttype_t *a_type = nullptr;
if( func->m_elemental && args.size() == 1 &&
ASRUtils::is_array(ASRUtils::expr_type(args[0].m_value)) ) {
Expand Down Expand Up @@ -766,39 +780,25 @@ class CommonVisitor : public AST::BaseVisitor<Derived> {
return func_call_asr;
}
} else {
std::map<std::string, ASR::ttype_t*> subs;
for (size_t i=0; i<args.size(); i++) {
ASR::ttype_t *param_type = ASRUtils::expr_type(func->m_args[i]);
ASR::ttype_t *arg_type = ASRUtils::expr_type(args[i].m_value);
subs = check_type_substitution(subs, param_type, arg_type, loc);
}

ASR::symbol_t *t = get_generic_function(subs, *func);
std::string new_call_name = call_name;
if (ASR::is_a<ASR::Function_t>(*t)) {
new_call_name = (ASR::down_cast<ASR::Function_t>(t))->m_name;
ASR::Function_t *func = ASR::down_cast<ASR::Function_t>(s);
if (args.size() != func->n_args) {
std::string fnd = std::to_string(args.size());
std::string org = std::to_string(func->n_args);
diag.add(diag::Diagnostic(
"Number of arguments does not match in the function call",
diag::Level::Error, diag::Stage::Semantic, {
diag::Label("(found: '" + fnd + "', expected: '" + org + "')",
{loc})
})
);
throw SemanticAbort();
}
return make_call_helper(al, t, current_scope, args, new_call_name, loc);
}
} else if (ASR::is_a<ASR::Function_t>(*s)) {
ASR::Function_t *func = ASR::down_cast<ASR::Function_t>(s);
if (args.size() != func->n_args) {
std::string fnd = std::to_string(args.size());
std::string org = std::to_string(func->n_args);
diag.add(diag::Diagnostic(
"Number of arguments does not match in the function call",
diag::Level::Error, diag::Stage::Semantic, {
diag::Label("(found: '" + fnd + "', expected: '" + org + "')",
{loc})
})
);
throw SemanticAbort();
Vec<ASR::call_arg_t> args_new;
args_new.reserve(al, func->n_args);
visit_expr_list_with_cast(func->m_args, func->n_args, args_new, args);
return ASR::make_SubroutineCall_t(al, loc, stemp,
s_generic, args_new.p, args_new.size(), nullptr);
}
Vec<ASR::call_arg_t> args_new;
args_new.reserve(al, func->n_args);
visit_expr_list_with_cast(func->m_args, func->n_args, args_new, args);
return ASR::make_SubroutineCall_t(al, loc, stemp,
s_generic, args_new.p, args_new.size(), nullptr);
} else if(ASR::is_a<ASR::DerivedType_t>(*s)) {
Vec<ASR::expr_t*> args_new;
args_new.reserve(al, args.size());
Expand Down Expand Up @@ -891,8 +891,9 @@ class CommonVisitor : public AST::BaseVisitor<Derived> {
new_function_num = 0;
}
generic_func_nums[func_name] = new_function_num + 1;
generic_func_subs["__lpython_generic_" + func_name + "_" + std::to_string(new_function_num)] = subs;
t = pass_instantiate_generic_function(al, subs, current_scope, new_function_num, func);
std::string new_func_name = "__lpython_generic_" + func_name + "_" + std::to_string(new_function_num);
generic_func_subs[new_func_name] = subs;
t = pass_instantiate_generic_function(al, subs, current_scope, new_func_name, func);
return t;
}

Expand Down Expand Up @@ -2273,7 +2274,8 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
current_procedure_abi_type = ASR::abiType::Source;
bool current_procedure_interface = false;
bool overload = false;
std::set<std::string> ps;
Vec<ASR::ttype_t*> tps;
tps.reserve(al, x.m_args.n_args);
bool vectorize = false;
if (x.n_decorator_list > 0) {
for(size_t i=0; i<x.n_decorator_list; i++) {
Expand Down Expand Up @@ -2310,8 +2312,25 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
ASR::ttype_t *arg_type = ast_expr_to_asr_type(x.base.base.loc, *x.m_args.m_args[i].m_annotation);
// Set the function as generic if an argument is typed with a type parameter
if (ASRUtils::is_generic(*arg_type)) {
std::string param_name = ASRUtils::get_parameter_name(arg_type);
ps.insert(param_name);
ASR::ttype_t *new_tt = ASRUtils::duplicate_type_without_dims(al, ASRUtils::get_type_parameter(arg_type));
size_t current_size = tps.size();
if (current_size == 0) {
tps.push_back(al, new_tt);
} else {
bool not_found = true;
for (size_t i = 0; i < current_size; i++) {
ASR::TypeParameter_t *added_tp = ASR::down_cast<ASR::TypeParameter_t>(tps.p[i]);
std::string new_param = ASR::down_cast<ASR::TypeParameter_t>(new_tt)->m_param;
std::string added_param = added_tp->m_param;
if (added_param.compare(new_param) == 0) {
not_found = false;
break;
}
}
if (not_found) {
tps.push_back(al, new_tt);
}
}
}

std::string arg_s = arg;
Expand Down Expand Up @@ -2377,43 +2396,19 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
ASR::down_cast<ASR::symbol_t>(return_var));
ASR::asr_t *return_var_ref = ASR::make_Var_t(al, x.base.base.loc,
current_scope->get_symbol(return_var_name));
if (ps.size() > 0) {
Vec<ASR::ttype_t*> type_params;
type_params.reserve(al, ps.size());
for (auto &p: ps) {
std::string param = p;
ASR::ttype_t *type_p = ASRUtils::TYPE(ASR::make_TypeParameter_t(al,
x.base.base.loc, s2c(al, p), nullptr, 0));
type_params.push_back(al, type_p);
}
tmp = ASR::make_Function_t(
al, x.base.base.loc,
/* a_symtab */ current_scope,
/* a_name */ s2c(al, sym_name),
/* a_args */ args.p,
/* n_args */ args.size(),
/* a_type_params */ type_params.p,
/* n_type_params */ type_params.size(),
/* a_body */ nullptr,
/* n_body */ 0,
/* a_return_var */ ASRUtils::EXPR(return_var_ref),
current_procedure_abi_type,
s_access, deftype, bindc_name, vectorize, false, false);
} else {
tmp = ASR::make_Function_t(
al, x.base.base.loc,
/* a_symtab */ current_scope,
/* a_name */ s2c(al, sym_name),
/* a_args */ args.p,
/* n_args */ args.size(),
/* a_type_params */ nullptr,
/* n_type_params */ 0,
/* a_body */ nullptr,
/* n_body */ 0,
/* a_return_var */ ASRUtils::EXPR(return_var_ref),
current_procedure_abi_type,
s_access, deftype, bindc_name, vectorize, false, false);
}
tmp = ASR::make_Function_t(
al, x.base.base.loc,
/* a_symtab */ current_scope,
/* a_name */ s2c(al, sym_name),
/* a_args */ args.p,
/* n_args */ args.size(),
/* a_type_params */ tps.p,
/* n_type_params */ tps.size(),
/* a_body */ nullptr,
/* n_body */ 0,
/* a_return_var */ ASRUtils::EXPR(return_var_ref),
current_procedure_abi_type,
s_access, deftype, bindc_name, vectorize, false, false);
} else {
throw SemanticError("Return variable must be an identifier (Name AST node) or an array (Subscript AST node)",
x.m_returns->base.loc);
Expand All @@ -2426,7 +2421,8 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
/* a_name */ s2c(al, sym_name),
/* a_args */ args.p,
/* n_args */ args.size(),
nullptr, 0,
/* a_type_params */ tps.p,
/* n_type_params */ tps.size(),
/* a_body */ nullptr,
/* n_body */ 0,
nullptr,
Expand Down
Loading