diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index e61654d7d1..022350976c 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -218,5 +218,6 @@ RUN(NAME test_str_comparison LABELS cpython llvm) RUN(NAME test_bit_length LABELS cpython llvm) RUN(NAME generics_01 LABELS cpython llvm) +RUN(NAME generics_02 LABELS cpython llvm) RUN(NAME generics_array_01 LABELS llvm) RUN(NAME test_statistics LABELS cpython llvm) diff --git a/integration_tests/generics_02.py b/integration_tests/generics_02.py new file mode 100644 index 0000000000..5e36813211 --- /dev/null +++ b/integration_tests/generics_02.py @@ -0,0 +1,13 @@ +from ltypes import TypeVar + +T = TypeVar('T') + +def swap(x: T, y: T): + temp: T + temp = x + x = y + y = temp + print(x) + print(y) + +swap(1,2) \ No newline at end of file diff --git a/src/libasr/asr_utils.h b/src/libasr/asr_utils.h index caa6a8f548..56ec6af347 100644 --- a/src/libasr/asr_utils.h +++ b/src/libasr/asr_utils.h @@ -1169,6 +1169,22 @@ static inline ASR::ttype_t* duplicate_type_without_dims(Allocator& al, const ASR return ASRUtils::TYPE(ASR::make_Integer_t(al, t->base.loc, tnew->m_kind, nullptr, 0)); } + case ASR::ttypeType::Real: { + ASR::Real_t* tnew = ASR::down_cast(t); + return ASRUtils::TYPE(ASR::make_Real_t(al, t->base.loc, + tnew->m_kind, nullptr, 0)); + } + case ASR::ttypeType::Character: { + ASR::Character_t* tnew = ASR::down_cast(t); + return ASRUtils::TYPE(ASR::make_Character_t(al, t->base.loc, + tnew->m_kind, tnew->m_len, tnew->m_len_expr, + nullptr, 0)); + } + case ASR::ttypeType::TypeParameter: { + ASR::TypeParameter_t* tp = ASR::down_cast(t); + return ASRUtils::TYPE(ASR::make_TypeParameter_t(al, t->base.loc, + tp->m_param, nullptr, 0)); + } default : throw LCompilersException("Not implemented " + std::to_string(t->type)); } } @@ -1387,6 +1403,19 @@ static inline ASR::ttype_t* get_contained_type(ASR::ttype_t* asr_type) { } } +static inline ASR::ttype_t* get_type_parameter(ASR::ttype_t* t) { + switch (t->type) { + case ASR::ttypeType::TypeParameter: { + return t; + } + case ASR::ttypeType::List: { + ASR::List_t *tl = ASR::down_cast(t); + return get_type_parameter(tl->m_type); + } + default: throw LCompilersException("Cannot get type parameter from this type."); + } +} + class ReplaceArgVisitor: public ASR::BaseExprReplacer { private: diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index 19942f59a9..61036d34c6 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -1008,7 +1008,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor // Generate function prototypes for (auto &item : x.m_global_scope->get_scope()) { if (is_a(*item.second)) { - visit_Function(*ASR::down_cast(item.second)); + if (ASR::down_cast(item.second)->n_type_params == 0) { + visit_Function(*ASR::down_cast(item.second)); + } } } prototype_only = false; @@ -1028,7 +1030,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor // Then do all the procedures for (auto &item : x.m_global_scope->get_scope()) { if (is_a(*item.second)) { - visit_symbol(*item.second); + if (ASR::down_cast(item.second)->n_type_params == 0) { + visit_symbol(*item.second); + } } } diff --git a/src/libasr/pass/instantiate_template.cpp b/src/libasr/pass/instantiate_template.cpp index b916e2700b..952256bed8 100644 --- a/src/libasr/pass/instantiate_template.cpp +++ b/src/libasr/pass/instantiate_template.cpp @@ -11,23 +11,20 @@ class FunctionInstantiator : public ASR::BaseExprStmtDuplicator subs; - int new_function_num; + std::string new_func_name; FunctionInstantiator(Allocator &al, std::map subs, - SymbolTable *current_scope, int new_function_num): + SymbolTable *current_scope, std::string new_func_name): BaseExprStmtDuplicator(al), current_scope{current_scope}, subs{subs}, - new_function_num{new_function_num} + new_func_name{new_func_name} {} ASR::asr_t* instantiate_Function(ASR::Function_t &x) { SymbolTable *parent_scope = current_scope; current_scope = al.make_new(parent_scope); - std::string func_name = x.m_name; - func_name = "__lpython_generic_" + func_name + "_" + std::to_string(new_function_num); - Vec args; args.reserve(al, x.n_args); for (size_t i=0; i( - (ASR::down_cast(x.m_return_var))->m_v); - std::string return_var_name = return_var->m_name; - ASR::ttype_t *return_param_type = ASRUtils::expr_type(x.m_return_var); - ASR::ttype_t *return_type = ASR::is_a(*return_param_type) ? - subs[ASR::down_cast(return_param_type)->m_param] : return_param_type; - ASR::asr_t *new_return_var = ASR::make_Variable_t(al, return_var->base.base.loc, - current_scope, s2c(al, return_var_name), return_var->m_intent, nullptr, nullptr, - return_var->m_storage, return_type, return_var->m_abi, return_var->m_access, - return_var->m_presence, return_var->m_value_attr); - current_scope->add_symbol(return_var_name, ASR::down_cast(new_return_var)); - ASR::asr_t *new_return_var_ref = ASR::make_Var_t(al, x.base.base.loc, - current_scope->get_symbol(return_var_name)); + ASR::expr_t *new_return_var_ref = nullptr; + if (x.m_return_var != nullptr) { + ASR::Variable_t *return_var = ASR::down_cast( + (ASR::down_cast(x.m_return_var))->m_v); + std::string return_var_name = return_var->m_name; + ASR::ttype_t *return_param_type = ASRUtils::expr_type(x.m_return_var); + ASR::ttype_t *return_type = substitute_type(return_param_type); + ASR::asr_t *new_return_var = ASR::make_Variable_t(al, return_var->base.base.loc, + current_scope, s2c(al, return_var_name), return_var->m_intent, nullptr, nullptr, + return_var->m_storage, return_type, return_var->m_abi, return_var->m_access, + return_var->m_presence, return_var->m_value_attr); + current_scope->add_symbol(return_var_name, ASR::down_cast(new_return_var)); + new_return_var_ref = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, + current_scope->get_symbol(return_var_name))); + } // Rebuild the symbol table for (auto const &sym_pair: x.m_symtab->get_scope()) { @@ -104,16 +103,16 @@ class FunctionInstantiator : public ASR::BaseExprStmtDuplicator(result); - parent_scope->add_symbol(func_name, t); + parent_scope->add_symbol(new_func_name, t); current_scope = parent_scope; return result; @@ -322,8 +321,8 @@ class FunctionInstantiator : public ASR::BaseExprStmtDuplicator subs, - SymbolTable *current_scope, int new_function_num, ASR::Function_t &func) { - FunctionInstantiator tf(al, subs, current_scope, new_function_num); + SymbolTable *current_scope, std::string new_func_name, ASR::Function_t &func) { + FunctionInstantiator tf(al, subs, current_scope, new_func_name); ASR::asr_t *new_function = tf.instantiate_Function(func); return ASR::down_cast(new_function); } diff --git a/src/libasr/pass/instantiate_template.h b/src/libasr/pass/instantiate_template.h index 0c9d9130ae..f6752da200 100644 --- a/src/libasr/pass/instantiate_template.h +++ b/src/libasr/pass/instantiate_template.h @@ -6,7 +6,7 @@ namespace LFortran { ASR::symbol_t* pass_instantiate_generic_function(Allocator &al, std::map subs, - SymbolTable *current_scope, int new_function_num, ASR::Function_t &func); + SymbolTable *current_scope, std::string new_func_name, ASR::Function_t &func); } #endif // LFORTRAN_PASS_TEMPLATE_VISITOR_H \ No newline at end of file diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 9e55bae10b..8f70605531 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -714,10 +714,24 @@ class CommonVisitor : public AST::BaseVisitor { stemp = symtab->get_symbol(local_sym); } } - if (ASR::is_a(*s) && - ASR::down_cast(s)->m_return_var != nullptr) { + if (ASR::is_a(*s)) { ASR::Function_t *func = ASR::down_cast(s); - if (func->n_type_params == 0) { + if (func->n_type_params > 0) { + std::map subs; + for (size_t i=0; im_args[i]); + ASR::ttype_t *arg_type = ASRUtils::expr_type(args[i].m_value); + subs = check_type_substitution(subs, param_type, arg_type, loc); + } + + ASR::symbol_t *t = get_generic_function(subs, *func); + std::string new_call_name = call_name; + if (ASR::is_a(*t)) { + new_call_name = (ASR::down_cast(t))->m_name; + } + return make_call_helper(al, t, current_scope, args, new_call_name, loc); + } + if (ASR::down_cast(s)->m_return_var != nullptr) { ASR::ttype_t *a_type = nullptr; if( func->m_elemental && args.size() == 1 && ASRUtils::is_array(ASRUtils::expr_type(args[0].m_value)) ) { @@ -763,39 +777,25 @@ class CommonVisitor : public AST::BaseVisitor { return func_call_asr; } } else { - std::map subs; - for (size_t i=0; im_args[i]); - ASR::ttype_t *arg_type = ASRUtils::expr_type(args[i].m_value); - subs = check_type_substitution(subs, param_type, arg_type, loc); - } - - ASR::symbol_t *t = get_generic_function(subs, *func); - std::string new_call_name = call_name; - if (ASR::is_a(*t)) { - new_call_name = (ASR::down_cast(t))->m_name; + ASR::Function_t *func = ASR::down_cast(s); + if (args.size() != func->n_args) { + std::string fnd = std::to_string(args.size()); + std::string org = std::to_string(func->n_args); + diag.add(diag::Diagnostic( + "Number of arguments does not match in the function call", + diag::Level::Error, diag::Stage::Semantic, { + diag::Label("(found: '" + fnd + "', expected: '" + org + "')", + {loc}) + }) + ); + throw SemanticAbort(); } - return make_call_helper(al, t, current_scope, args, new_call_name, loc); - } - } else if (ASR::is_a(*s)) { - ASR::Function_t *func = ASR::down_cast(s); - if (args.size() != func->n_args) { - std::string fnd = std::to_string(args.size()); - std::string org = std::to_string(func->n_args); - diag.add(diag::Diagnostic( - "Number of arguments does not match in the function call", - diag::Level::Error, diag::Stage::Semantic, { - diag::Label("(found: '" + fnd + "', expected: '" + org + "')", - {loc}) - }) - ); - throw SemanticAbort(); + Vec args_new; + args_new.reserve(al, func->n_args); + visit_expr_list_with_cast(func->m_args, func->n_args, args_new, args); + return ASR::make_SubroutineCall_t(al, loc, stemp, + s_generic, args_new.p, args_new.size(), nullptr); } - Vec args_new; - args_new.reserve(al, func->n_args); - visit_expr_list_with_cast(func->m_args, func->n_args, args_new, args); - return ASR::make_SubroutineCall_t(al, loc, stemp, - s_generic, args_new.p, args_new.size(), nullptr); } else if(ASR::is_a(*s)) { Vec args_new; args_new.reserve(al, args.size()); @@ -888,8 +888,9 @@ class CommonVisitor : public AST::BaseVisitor { new_function_num = 0; } generic_func_nums[func_name] = new_function_num + 1; - generic_func_subs["__lpython_generic_" + func_name + "_" + std::to_string(new_function_num)] = subs; - t = pass_instantiate_generic_function(al, subs, current_scope, new_function_num, func); + std::string new_func_name = "__lpython_generic_" + func_name + "_" + std::to_string(new_function_num); + generic_func_subs[new_func_name] = subs; + t = pass_instantiate_generic_function(al, subs, current_scope, new_func_name, func); return t; } @@ -2270,7 +2271,8 @@ class SymbolTableVisitor : public CommonVisitor { current_procedure_abi_type = ASR::abiType::Source; bool current_procedure_interface = false; bool overload = false; - std::set ps; + Vec tps; + tps.reserve(al, x.m_args.n_args); bool vectorize = false; if (x.n_decorator_list > 0) { for(size_t i=0; i { ASR::ttype_t *arg_type = ast_expr_to_asr_type(x.base.base.loc, *x.m_args.m_args[i].m_annotation); // Set the function as generic if an argument is typed with a type parameter if (ASRUtils::is_generic(*arg_type)) { - std::string param_name = ASRUtils::get_parameter_name(arg_type); - ps.insert(param_name); + ASR::ttype_t *new_tt = ASRUtils::duplicate_type_without_dims(al, ASRUtils::get_type_parameter(arg_type)); + size_t current_size = tps.size(); + if (current_size == 0) { + tps.push_back(al, new_tt); + } else { + bool not_found = true; + for (size_t i = 0; i < current_size; i++) { + ASR::TypeParameter_t *added_tp = ASR::down_cast(tps.p[i]); + std::string new_param = ASR::down_cast(new_tt)->m_param; + std::string added_param = added_tp->m_param; + if (added_param.compare(new_param) == 0) { + not_found = false; + break; + } + } + if (not_found) { + tps.push_back(al, new_tt); + } + } } std::string arg_s = arg; @@ -2374,43 +2393,19 @@ class SymbolTableVisitor : public CommonVisitor { ASR::down_cast(return_var)); ASR::asr_t *return_var_ref = ASR::make_Var_t(al, x.base.base.loc, current_scope->get_symbol(return_var_name)); - if (ps.size() > 0) { - Vec type_params; - type_params.reserve(al, ps.size()); - for (auto &p: ps) { - std::string param = p; - ASR::ttype_t *type_p = ASRUtils::TYPE(ASR::make_TypeParameter_t(al, - x.base.base.loc, s2c(al, p), nullptr, 0)); - type_params.push_back(al, type_p); - } - tmp = ASR::make_Function_t( - al, x.base.base.loc, - /* a_symtab */ current_scope, - /* a_name */ s2c(al, sym_name), - /* a_args */ args.p, - /* n_args */ args.size(), - /* a_type_params */ type_params.p, - /* n_type_params */ type_params.size(), - /* a_body */ nullptr, - /* n_body */ 0, - /* a_return_var */ ASRUtils::EXPR(return_var_ref), - current_procedure_abi_type, - s_access, deftype, bindc_name, vectorize, false, false); - } else { - tmp = ASR::make_Function_t( - al, x.base.base.loc, - /* a_symtab */ current_scope, - /* a_name */ s2c(al, sym_name), - /* a_args */ args.p, - /* n_args */ args.size(), - /* a_type_params */ nullptr, - /* n_type_params */ 0, - /* a_body */ nullptr, - /* n_body */ 0, - /* a_return_var */ ASRUtils::EXPR(return_var_ref), - current_procedure_abi_type, - s_access, deftype, bindc_name, vectorize, false, false); - } + tmp = ASR::make_Function_t( + al, x.base.base.loc, + /* a_symtab */ current_scope, + /* a_name */ s2c(al, sym_name), + /* a_args */ args.p, + /* n_args */ args.size(), + /* a_type_params */ tps.p, + /* n_type_params */ tps.size(), + /* a_body */ nullptr, + /* n_body */ 0, + /* a_return_var */ ASRUtils::EXPR(return_var_ref), + current_procedure_abi_type, + s_access, deftype, bindc_name, vectorize, false, false); } else { throw SemanticError("Return variable must be an identifier (Name AST node) or an array (Subscript AST node)", x.m_returns->base.loc); @@ -2423,7 +2418,8 @@ class SymbolTableVisitor : public CommonVisitor { /* a_name */ s2c(al, sym_name), /* a_args */ args.p, /* n_args */ args.size(), - nullptr, 0, + /* a_type_params */ tps.p, + /* n_type_params */ tps.size(), /* a_body */ nullptr, /* n_body */ 0, nullptr, diff --git a/tests/reference/asr-generics_02-e2ea5c9.json b/tests/reference/asr-generics_02-e2ea5c9.json new file mode 100644 index 0000000000..311573d1be --- /dev/null +++ b/tests/reference/asr-generics_02-e2ea5c9.json @@ -0,0 +1,13 @@ +{ + "basename": "asr-generics_02-e2ea5c9", + "cmd": "lpython --show-asr --no-color {infile} -o {outfile}", + "infile": "tests/../integration_tests/generics_02.py", + "infile_hash": "6f748cbf059da328ff092c08ad0063e639ba460eb481fadfadeebc51", + "outfile": null, + "outfile_hash": null, + "stdout": "asr-generics_02-e2ea5c9.stdout", + "stdout_hash": "5e109b8427357ce9883ef4927ac34a69330c4948dde307af0a92de9c", + "stderr": null, + "stderr_hash": null, + "returncode": 0 +} \ No newline at end of file diff --git a/tests/reference/asr-generics_02-e2ea5c9.stdout b/tests/reference/asr-generics_02-e2ea5c9.stdout new file mode 100644 index 0000000000..c669a0f382 --- /dev/null +++ b/tests/reference/asr-generics_02-e2ea5c9.stdout @@ -0,0 +1 @@ +(TranslationUnit (SymbolTable 1 {T: (Variable 1 T Local () () Default (TypeParameter T []) Source Public Required .false.), __lpython_generic_swap_0: (Function (SymbolTable 3 {temp: (Variable 3 temp Local () () Default (Integer 4 []) Source Public Required .false.), x: (Variable 3 x In () () Default (Integer 4 []) Source Public Required .false.), y: (Variable 3 y In () () Default (Integer 4 []) Source Public Required .false.)}) __lpython_generic_swap_0 [(Var 3 x) (Var 3 y)] [] [(= (Var 3 temp) (Var 3 x) ()) (= (Var 3 x) (Var 3 y) ()) (= (Var 3 y) (Var 3 temp) ()) (Print () [(Var 3 x)] () ()) (Print () [(Var 3 y)] () ())] () Source Public Implementation () .false. .false. .false.), _lpython_main_program: (Function (SymbolTable 5 {}) _lpython_main_program [] [] [(SubroutineCall 1 __lpython_generic_swap_0 () [((IntegerConstant 1 (Integer 4 []))) ((IntegerConstant 2 (Integer 4 [])))] ())] () Source Public Implementation () .false. .false. .false.), main_program: (Program (SymbolTable 4 {}) main_program [] [(SubroutineCall 1 _lpython_main_program () [] ())]), swap: (Function (SymbolTable 2 {temp: (Variable 2 temp Local () () Default (TypeParameter T []) Source Public Required .false.), x: (Variable 2 x In () () Default (TypeParameter T []) Source Public Required .false.), y: (Variable 2 y In () () Default (TypeParameter T []) Source Public Required .false.)}) swap [(Var 2 x) (Var 2 y)] [(TypeParameter T [])] [(= (Var 2 temp) (Var 2 x) ()) (= (Var 2 x) (Var 2 y) ()) (= (Var 2 y) (Var 2 temp) ()) (Print () [(Var 2 x)] () ()) (Print () [(Var 2 y)] () ())] () Source Public Implementation () .false. .false. .false.)}) []) diff --git a/tests/reference/cpp-expr3-9c516d4.json b/tests/reference/cpp-expr3-9c516d4.json deleted file mode 100644 index fa9e9258d6..0000000000 --- a/tests/reference/cpp-expr3-9c516d4.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "basename": "cpp-expr3-9c516d4", - "cmd": "lpython --no-color --show-cpp {infile}", - "infile": "tests/expr3.py", - "infile_hash": "4fbbd9ddebefcc9afdd6fdc17e16313fafc5b3e214595d6ad62c10cb", - "outfile": null, - "outfile_hash": null, - "stdout": null, - "stdout_hash": null, - "stderr": "cpp-expr3-9c516d4.stderr", - "stderr_hash": "715f5ea03b41d70718afea8c302485a5cf9e2602b87293bebd147d43", - "returncode": 2 -} \ No newline at end of file diff --git a/tests/tests.toml b/tests/tests.toml index 6d97422ec3..4336541a13 100644 --- a/tests/tests.toml +++ b/tests/tests.toml @@ -377,6 +377,10 @@ cpp = true filename = "../integration_tests/generics_01.py" asr = true +[[test]] +filename = "../integration_tests/generics_02.py" +asr = true + [[test]] filename = "../integration_tests/generics_array_01.py" asr = true