Skip to content

Commit c982a7a

Browse files
authored
Merge pull request #989 from ansharlubis/subroutine-template-clean
Adding type parameters to functions without return values (cleaned history)
2 parents 9c7e557 + 47a38f9 commit c982a7a

11 files changed

+164
-117
lines changed

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,5 +222,6 @@ RUN(NAME test_str_comparison LABELS cpython llvm)
222222
RUN(NAME test_bit_length LABELS cpython llvm)
223223

224224
RUN(NAME generics_01 LABELS cpython llvm)
225+
RUN(NAME generics_02 LABELS cpython llvm)
225226
RUN(NAME generics_array_01 LABELS llvm)
226227
RUN(NAME test_statistics LABELS cpython llvm)

integration_tests/generics_02.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from ltypes import TypeVar
2+
3+
T = TypeVar('T')
4+
5+
def swap(x: T, y: T):
6+
temp: T
7+
temp = x
8+
x = y
9+
y = temp
10+
print(x)
11+
print(y)
12+
13+
swap(1,2)

src/libasr/asr_utils.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,6 +1169,22 @@ static inline ASR::ttype_t* duplicate_type_without_dims(Allocator& al, const ASR
11691169
return ASRUtils::TYPE(ASR::make_Integer_t(al, t->base.loc,
11701170
tnew->m_kind, nullptr, 0));
11711171
}
1172+
case ASR::ttypeType::Real: {
1173+
ASR::Real_t* tnew = ASR::down_cast<ASR::Real_t>(t);
1174+
return ASRUtils::TYPE(ASR::make_Real_t(al, t->base.loc,
1175+
tnew->m_kind, nullptr, 0));
1176+
}
1177+
case ASR::ttypeType::Character: {
1178+
ASR::Character_t* tnew = ASR::down_cast<ASR::Character_t>(t);
1179+
return ASRUtils::TYPE(ASR::make_Character_t(al, t->base.loc,
1180+
tnew->m_kind, tnew->m_len, tnew->m_len_expr,
1181+
nullptr, 0));
1182+
}
1183+
case ASR::ttypeType::TypeParameter: {
1184+
ASR::TypeParameter_t* tp = ASR::down_cast<ASR::TypeParameter_t>(t);
1185+
return ASRUtils::TYPE(ASR::make_TypeParameter_t(al, t->base.loc,
1186+
tp->m_param, nullptr, 0));
1187+
}
11721188
default : throw LCompilersException("Not implemented " + std::to_string(t->type));
11731189
}
11741190
}
@@ -1387,6 +1403,19 @@ static inline ASR::ttype_t* get_contained_type(ASR::ttype_t* asr_type) {
13871403
}
13881404
}
13891405

1406+
static inline ASR::ttype_t* get_type_parameter(ASR::ttype_t* t) {
1407+
switch (t->type) {
1408+
case ASR::ttypeType::TypeParameter: {
1409+
return t;
1410+
}
1411+
case ASR::ttypeType::List: {
1412+
ASR::List_t *tl = ASR::down_cast<ASR::List_t>(t);
1413+
return get_type_parameter(tl->m_type);
1414+
}
1415+
default: throw LCompilersException("Cannot get type parameter from this type.");
1416+
}
1417+
}
1418+
13901419
class ReplaceArgVisitor: public ASR::BaseExprReplacer<ReplaceArgVisitor> {
13911420

13921421
private:

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,7 +1013,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
10131013
// Generate function prototypes
10141014
for (auto &item : x.m_global_scope->get_scope()) {
10151015
if (is_a<ASR::Function_t>(*item.second)) {
1016-
visit_Function(*ASR::down_cast<ASR::Function_t>(item.second));
1016+
if (ASR::down_cast<ASR::Function_t>(item.second)->n_type_params == 0) {
1017+
visit_Function(*ASR::down_cast<ASR::Function_t>(item.second));
1018+
}
10171019
}
10181020
}
10191021
prototype_only = false;
@@ -1033,7 +1035,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
10331035
// Then do all the procedures
10341036
for (auto &item : x.m_global_scope->get_scope()) {
10351037
if (is_a<ASR::Function_t>(*item.second)) {
1036-
visit_symbol(*item.second);
1038+
if (ASR::down_cast<ASR::Function_t>(item.second)->n_type_params == 0) {
1039+
visit_symbol(*item.second);
1040+
}
10371041
}
10381042
}
10391043

src/libasr/pass/instantiate_template.cpp

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,20 @@ class FunctionInstantiator : public ASR::BaseExprStmtDuplicator<FunctionInstanti
1111
public:
1212
SymbolTable *current_scope;
1313
std::map<std::string, ASR::ttype_t*> subs;
14-
int new_function_num;
14+
std::string new_func_name;
1515

1616
FunctionInstantiator(Allocator &al, std::map<std::string, ASR::ttype_t*> subs,
17-
SymbolTable *current_scope, int new_function_num):
17+
SymbolTable *current_scope, std::string new_func_name):
1818
BaseExprStmtDuplicator(al),
1919
current_scope{current_scope},
2020
subs{subs},
21-
new_function_num{new_function_num}
21+
new_func_name{new_func_name}
2222
{}
2323

2424
ASR::asr_t* instantiate_Function(ASR::Function_t &x) {
2525
SymbolTable *parent_scope = current_scope;
2626
current_scope = al.make_new<SymbolTable>(parent_scope);
2727

28-
std::string func_name = x.m_name;
29-
func_name = "__lpython_generic_" + func_name + "_" + std::to_string(new_function_num);
30-
3128
Vec<ASR::expr_t*> args;
3229
args.reserve(al, x.n_args);
3330
for (size_t i=0; i<x.n_args; i++) {
@@ -57,19 +54,21 @@ class FunctionInstantiator : public ASR::BaseExprStmtDuplicator<FunctionInstanti
5754
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var)));
5855
}
5956

60-
ASR::Variable_t *return_var = ASR::down_cast<ASR::Variable_t>(
61-
(ASR::down_cast<ASR::Var_t>(x.m_return_var))->m_v);
62-
std::string return_var_name = return_var->m_name;
63-
ASR::ttype_t *return_param_type = ASRUtils::expr_type(x.m_return_var);
64-
ASR::ttype_t *return_type = ASR::is_a<ASR::TypeParameter_t>(*return_param_type) ?
65-
subs[ASR::down_cast<ASR::TypeParameter_t>(return_param_type)->m_param] : return_param_type;
66-
ASR::asr_t *new_return_var = ASR::make_Variable_t(al, return_var->base.base.loc,
67-
current_scope, s2c(al, return_var_name), return_var->m_intent, nullptr, nullptr,
68-
return_var->m_storage, return_type, return_var->m_abi, return_var->m_access,
69-
return_var->m_presence, return_var->m_value_attr);
70-
current_scope->add_symbol(return_var_name, ASR::down_cast<ASR::symbol_t>(new_return_var));
71-
ASR::asr_t *new_return_var_ref = ASR::make_Var_t(al, x.base.base.loc,
72-
current_scope->get_symbol(return_var_name));
57+
ASR::expr_t *new_return_var_ref = nullptr;
58+
if (x.m_return_var != nullptr) {
59+
ASR::Variable_t *return_var = ASR::down_cast<ASR::Variable_t>(
60+
(ASR::down_cast<ASR::Var_t>(x.m_return_var))->m_v);
61+
std::string return_var_name = return_var->m_name;
62+
ASR::ttype_t *return_param_type = ASRUtils::expr_type(x.m_return_var);
63+
ASR::ttype_t *return_type = substitute_type(return_param_type);
64+
ASR::asr_t *new_return_var = ASR::make_Variable_t(al, return_var->base.base.loc,
65+
current_scope, s2c(al, return_var_name), return_var->m_intent, nullptr, nullptr,
66+
return_var->m_storage, return_type, return_var->m_abi, return_var->m_access,
67+
return_var->m_presence, return_var->m_value_attr);
68+
current_scope->add_symbol(return_var_name, ASR::down_cast<ASR::symbol_t>(new_return_var));
69+
new_return_var_ref = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc,
70+
current_scope->get_symbol(return_var_name)));
71+
}
7372

7473
// Rebuild the symbol table
7574
for (auto const &sym_pair: x.m_symtab->get_scope()) {
@@ -104,16 +103,16 @@ class FunctionInstantiator : public ASR::BaseExprStmtDuplicator<FunctionInstanti
104103

105104
ASR::asr_t *result = ASR::make_Function_t(
106105
al, x.base.base.loc,
107-
current_scope, s2c(al, func_name),
106+
current_scope, s2c(al, new_func_name),
108107
args.p, args.size(),
109108
nullptr, 0,
110109
body.p, body.size(),
111-
ASRUtils::EXPR(new_return_var_ref),
110+
new_return_var_ref,
112111
func_abi, func_access, func_deftype, bindc_name,
113112
false, false, false);
114113

115114
ASR::symbol_t *t = ASR::down_cast<ASR::symbol_t>(result);
116-
parent_scope->add_symbol(func_name, t);
115+
parent_scope->add_symbol(new_func_name, t);
117116
current_scope = parent_scope;
118117

119118
return result;
@@ -322,8 +321,8 @@ class FunctionInstantiator : public ASR::BaseExprStmtDuplicator<FunctionInstanti
322321
};
323322

324323
ASR::symbol_t* pass_instantiate_generic_function(Allocator &al, std::map<std::string, ASR::ttype_t*> subs,
325-
SymbolTable *current_scope, int new_function_num, ASR::Function_t &func) {
326-
FunctionInstantiator tf(al, subs, current_scope, new_function_num);
324+
SymbolTable *current_scope, std::string new_func_name, ASR::Function_t &func) {
325+
FunctionInstantiator tf(al, subs, current_scope, new_func_name);
327326
ASR::asr_t *new_function = tf.instantiate_Function(func);
328327
return ASR::down_cast<ASR::symbol_t>(new_function);
329328
}

src/libasr/pass/instantiate_template.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
namespace LFortran {
77

88
ASR::symbol_t* pass_instantiate_generic_function(Allocator &al, std::map<std::string, ASR::ttype_t*> subs,
9-
SymbolTable *current_scope, int new_function_num, ASR::Function_t &func);
9+
SymbolTable *current_scope, std::string new_func_name, ASR::Function_t &func);
1010
}
1111

1212
#endif // LFORTRAN_PASS_TEMPLATE_VISITOR_H

src/lpython/semantics/python_ast_to_asr.cpp

Lines changed: 73 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -717,10 +717,24 @@ class CommonVisitor : public AST::BaseVisitor<Derived> {
717717
stemp = symtab->get_symbol(local_sym);
718718
}
719719
}
720-
if (ASR::is_a<ASR::Function_t>(*s) &&
721-
ASR::down_cast<ASR::Function_t>(s)->m_return_var != nullptr) {
720+
if (ASR::is_a<ASR::Function_t>(*s)) {
722721
ASR::Function_t *func = ASR::down_cast<ASR::Function_t>(s);
723-
if (func->n_type_params == 0) {
722+
if (func->n_type_params > 0) {
723+
std::map<std::string, ASR::ttype_t*> subs;
724+
for (size_t i=0; i<args.size(); i++) {
725+
ASR::ttype_t *param_type = ASRUtils::expr_type(func->m_args[i]);
726+
ASR::ttype_t *arg_type = ASRUtils::expr_type(args[i].m_value);
727+
subs = check_type_substitution(subs, param_type, arg_type, loc);
728+
}
729+
730+
ASR::symbol_t *t = get_generic_function(subs, *func);
731+
std::string new_call_name = call_name;
732+
if (ASR::is_a<ASR::Function_t>(*t)) {
733+
new_call_name = (ASR::down_cast<ASR::Function_t>(t))->m_name;
734+
}
735+
return make_call_helper(al, t, current_scope, args, new_call_name, loc);
736+
}
737+
if (ASR::down_cast<ASR::Function_t>(s)->m_return_var != nullptr) {
724738
ASR::ttype_t *a_type = nullptr;
725739
if( func->m_elemental && args.size() == 1 &&
726740
ASRUtils::is_array(ASRUtils::expr_type(args[0].m_value)) ) {
@@ -766,39 +780,25 @@ class CommonVisitor : public AST::BaseVisitor<Derived> {
766780
return func_call_asr;
767781
}
768782
} else {
769-
std::map<std::string, ASR::ttype_t*> subs;
770-
for (size_t i=0; i<args.size(); i++) {
771-
ASR::ttype_t *param_type = ASRUtils::expr_type(func->m_args[i]);
772-
ASR::ttype_t *arg_type = ASRUtils::expr_type(args[i].m_value);
773-
subs = check_type_substitution(subs, param_type, arg_type, loc);
774-
}
775-
776-
ASR::symbol_t *t = get_generic_function(subs, *func);
777-
std::string new_call_name = call_name;
778-
if (ASR::is_a<ASR::Function_t>(*t)) {
779-
new_call_name = (ASR::down_cast<ASR::Function_t>(t))->m_name;
783+
ASR::Function_t *func = ASR::down_cast<ASR::Function_t>(s);
784+
if (args.size() != func->n_args) {
785+
std::string fnd = std::to_string(args.size());
786+
std::string org = std::to_string(func->n_args);
787+
diag.add(diag::Diagnostic(
788+
"Number of arguments does not match in the function call",
789+
diag::Level::Error, diag::Stage::Semantic, {
790+
diag::Label("(found: '" + fnd + "', expected: '" + org + "')",
791+
{loc})
792+
})
793+
);
794+
throw SemanticAbort();
780795
}
781-
return make_call_helper(al, t, current_scope, args, new_call_name, loc);
782-
}
783-
} else if (ASR::is_a<ASR::Function_t>(*s)) {
784-
ASR::Function_t *func = ASR::down_cast<ASR::Function_t>(s);
785-
if (args.size() != func->n_args) {
786-
std::string fnd = std::to_string(args.size());
787-
std::string org = std::to_string(func->n_args);
788-
diag.add(diag::Diagnostic(
789-
"Number of arguments does not match in the function call",
790-
diag::Level::Error, diag::Stage::Semantic, {
791-
diag::Label("(found: '" + fnd + "', expected: '" + org + "')",
792-
{loc})
793-
})
794-
);
795-
throw SemanticAbort();
796+
Vec<ASR::call_arg_t> args_new;
797+
args_new.reserve(al, func->n_args);
798+
visit_expr_list_with_cast(func->m_args, func->n_args, args_new, args);
799+
return ASR::make_SubroutineCall_t(al, loc, stemp,
800+
s_generic, args_new.p, args_new.size(), nullptr);
796801
}
797-
Vec<ASR::call_arg_t> args_new;
798-
args_new.reserve(al, func->n_args);
799-
visit_expr_list_with_cast(func->m_args, func->n_args, args_new, args);
800-
return ASR::make_SubroutineCall_t(al, loc, stemp,
801-
s_generic, args_new.p, args_new.size(), nullptr);
802802
} else if(ASR::is_a<ASR::DerivedType_t>(*s)) {
803803
Vec<ASR::expr_t*> args_new;
804804
args_new.reserve(al, args.size());
@@ -891,8 +891,9 @@ class CommonVisitor : public AST::BaseVisitor<Derived> {
891891
new_function_num = 0;
892892
}
893893
generic_func_nums[func_name] = new_function_num + 1;
894-
generic_func_subs["__lpython_generic_" + func_name + "_" + std::to_string(new_function_num)] = subs;
895-
t = pass_instantiate_generic_function(al, subs, current_scope, new_function_num, func);
894+
std::string new_func_name = "__lpython_generic_" + func_name + "_" + std::to_string(new_function_num);
895+
generic_func_subs[new_func_name] = subs;
896+
t = pass_instantiate_generic_function(al, subs, current_scope, new_func_name, func);
896897
return t;
897898
}
898899

@@ -2273,7 +2274,8 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
22732274
current_procedure_abi_type = ASR::abiType::Source;
22742275
bool current_procedure_interface = false;
22752276
bool overload = false;
2276-
std::set<std::string> ps;
2277+
Vec<ASR::ttype_t*> tps;
2278+
tps.reserve(al, x.m_args.n_args);
22772279
bool vectorize = false;
22782280
if (x.n_decorator_list > 0) {
22792281
for(size_t i=0; i<x.n_decorator_list; i++) {
@@ -2310,8 +2312,25 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
23102312
ASR::ttype_t *arg_type = ast_expr_to_asr_type(x.base.base.loc, *x.m_args.m_args[i].m_annotation);
23112313
// Set the function as generic if an argument is typed with a type parameter
23122314
if (ASRUtils::is_generic(*arg_type)) {
2313-
std::string param_name = ASRUtils::get_parameter_name(arg_type);
2314-
ps.insert(param_name);
2315+
ASR::ttype_t *new_tt = ASRUtils::duplicate_type_without_dims(al, ASRUtils::get_type_parameter(arg_type));
2316+
size_t current_size = tps.size();
2317+
if (current_size == 0) {
2318+
tps.push_back(al, new_tt);
2319+
} else {
2320+
bool not_found = true;
2321+
for (size_t i = 0; i < current_size; i++) {
2322+
ASR::TypeParameter_t *added_tp = ASR::down_cast<ASR::TypeParameter_t>(tps.p[i]);
2323+
std::string new_param = ASR::down_cast<ASR::TypeParameter_t>(new_tt)->m_param;
2324+
std::string added_param = added_tp->m_param;
2325+
if (added_param.compare(new_param) == 0) {
2326+
not_found = false;
2327+
break;
2328+
}
2329+
}
2330+
if (not_found) {
2331+
tps.push_back(al, new_tt);
2332+
}
2333+
}
23152334
}
23162335

23172336
std::string arg_s = arg;
@@ -2377,43 +2396,19 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
23772396
ASR::down_cast<ASR::symbol_t>(return_var));
23782397
ASR::asr_t *return_var_ref = ASR::make_Var_t(al, x.base.base.loc,
23792398
current_scope->get_symbol(return_var_name));
2380-
if (ps.size() > 0) {
2381-
Vec<ASR::ttype_t*> type_params;
2382-
type_params.reserve(al, ps.size());
2383-
for (auto &p: ps) {
2384-
std::string param = p;
2385-
ASR::ttype_t *type_p = ASRUtils::TYPE(ASR::make_TypeParameter_t(al,
2386-
x.base.base.loc, s2c(al, p), nullptr, 0));
2387-
type_params.push_back(al, type_p);
2388-
}
2389-
tmp = ASR::make_Function_t(
2390-
al, x.base.base.loc,
2391-
/* a_symtab */ current_scope,
2392-
/* a_name */ s2c(al, sym_name),
2393-
/* a_args */ args.p,
2394-
/* n_args */ args.size(),
2395-
/* a_type_params */ type_params.p,
2396-
/* n_type_params */ type_params.size(),
2397-
/* a_body */ nullptr,
2398-
/* n_body */ 0,
2399-
/* a_return_var */ ASRUtils::EXPR(return_var_ref),
2400-
current_procedure_abi_type,
2401-
s_access, deftype, bindc_name, vectorize, false, false);
2402-
} else {
2403-
tmp = ASR::make_Function_t(
2404-
al, x.base.base.loc,
2405-
/* a_symtab */ current_scope,
2406-
/* a_name */ s2c(al, sym_name),
2407-
/* a_args */ args.p,
2408-
/* n_args */ args.size(),
2409-
/* a_type_params */ nullptr,
2410-
/* n_type_params */ 0,
2411-
/* a_body */ nullptr,
2412-
/* n_body */ 0,
2413-
/* a_return_var */ ASRUtils::EXPR(return_var_ref),
2414-
current_procedure_abi_type,
2415-
s_access, deftype, bindc_name, vectorize, false, false);
2416-
}
2399+
tmp = ASR::make_Function_t(
2400+
al, x.base.base.loc,
2401+
/* a_symtab */ current_scope,
2402+
/* a_name */ s2c(al, sym_name),
2403+
/* a_args */ args.p,
2404+
/* n_args */ args.size(),
2405+
/* a_type_params */ tps.p,
2406+
/* n_type_params */ tps.size(),
2407+
/* a_body */ nullptr,
2408+
/* n_body */ 0,
2409+
/* a_return_var */ ASRUtils::EXPR(return_var_ref),
2410+
current_procedure_abi_type,
2411+
s_access, deftype, bindc_name, vectorize, false, false);
24172412
} else {
24182413
throw SemanticError("Return variable must be an identifier (Name AST node) or an array (Subscript AST node)",
24192414
x.m_returns->base.loc);
@@ -2426,7 +2421,8 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
24262421
/* a_name */ s2c(al, sym_name),
24272422
/* a_args */ args.p,
24282423
/* n_args */ args.size(),
2429-
nullptr, 0,
2424+
/* a_type_params */ tps.p,
2425+
/* n_type_params */ tps.size(),
24302426
/* a_body */ nullptr,
24312427
/* n_body */ 0,
24322428
nullptr,

0 commit comments

Comments
 (0)