Skip to content

Commit 5323978

Browse files
committed
Use IntrinsicFunction infrastructure for list.index
1 parent 21f3db7 commit 5323978

9 files changed

+148
-40
lines changed

src/libasr/ASR.asdl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,6 @@ expr
255255
| ListConcat(expr left, expr right, ttype type, expr? value)
256256
| ListCompare(expr left, cmpop op, expr right, ttype type, expr? value)
257257
| ListCount(expr arg, expr ele, ttype type, expr? value)
258-
| ListIndex(expr arg, expr ele, ttype type, expr? value)
259258

260259
| SetConstant(expr* elements, ttype type)
261260
| SetLen(expr arg, ttype type, expr? value)

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
#include <libasr/codegen/llvm_utils.h>
5050
#include <libasr/codegen/llvm_array_utils.h>
5151

52+
#include <libasr/pass/intrinsic_function_registry.h>
53+
5254
#if LLVM_VERSION_MAJOR >= 11
5355
# define FIXED_VECTOR_TYPE llvm::FixedVectorType
5456
#else
@@ -1962,20 +1964,45 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
19621964
tmp = list_api->count(plist, item, asr_el_type, *module);
19631965
}
19641966

1965-
void visit_ListIndex(const ASR::ListIndex_t& x) {
1966-
ASR::ttype_t* asr_el_type = ASRUtils::get_contained_type(ASRUtils::expr_type(x.m_arg));
1967+
void generate_ListIndex(ASR::expr_t* m_arg, ASR::expr_t* m_ele) {
1968+
ASR::ttype_t* asr_el_type = ASRUtils::get_contained_type(ASRUtils::expr_type(m_arg));
19671969
int64_t ptr_loads_copy = ptr_loads;
19681970
ptr_loads = 0;
1969-
this->visit_expr(*x.m_arg);
1971+
this->visit_expr(*m_arg);
19701972
llvm::Value* plist = tmp;
19711973

19721974
ptr_loads = !LLVM::is_llvm_struct(asr_el_type);
1973-
this->visit_expr_wrapper(x.m_ele, true);
1975+
this->visit_expr_wrapper(m_ele, true);
19741976
ptr_loads = ptr_loads_copy;
19751977
llvm::Value *item = tmp;
19761978
tmp = list_api->index(plist, item, asr_el_type, *module);
19771979
}
19781980

1981+
void visit_IntrinsicFunction(const ASR::IntrinsicFunction_t& x) {
1982+
switch (static_cast<ASRUtils::IntrinsicFunctions>(x.m_intrinsic_id)) {
1983+
case ASRUtils::IntrinsicFunctions::ListIndex: {
1984+
switch (x.m_overload_id) {
1985+
case 0: {
1986+
ASR::expr_t* m_arg = x.m_args[0];
1987+
ASR::expr_t* m_ele = x.m_args[1];
1988+
generate_ListIndex(m_arg, m_ele);
1989+
break ;
1990+
}
1991+
default: {
1992+
throw CodeGenError("list.index only accepts one argument",
1993+
x.base.base.loc);
1994+
}
1995+
}
1996+
break ;
1997+
}
1998+
default: {
1999+
throw CodeGenError( ASRUtils::IntrinsicFunctionRegistry::
2000+
get_intrinsic_function_name(x.m_intrinsic_id) +
2001+
" is not implemented by LLVM backend.", x.base.base.loc);
2002+
}
2003+
}
2004+
}
2005+
19792006
void visit_ListClear(const ASR::ListClear_t& x) {
19802007
int64_t ptr_loads_copy = ptr_loads;
19812008
ptr_loads = 0;

src/libasr/pass/intrinsic_function.cpp

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,30 +38,31 @@ class ReplaceIntrinsicFunction: public ASR::BaseExprReplacer<ReplaceIntrinsicFun
3838

3939

4040
void replace_IntrinsicFunction(ASR::IntrinsicFunction_t* x) {
41-
LCOMPILERS_ASSERT(x->n_args == 1)
4241
Vec<ASR::call_arg_t> new_args;
4342
// Replace any IntrinsicFunctions in the argument first:
4443
{
45-
ASR::expr_t** current_expr_copy_ = current_expr;
46-
current_expr = &(x->m_args[0]);
47-
replace_expr(x->m_args[0]);
4844
new_args.reserve(al, x->n_args);
49-
ASR::call_arg_t arg0;
50-
arg0.m_value = *current_expr; // Use the converted arg
51-
new_args.push_back(al, arg0);
52-
current_expr = current_expr_copy_;
45+
for( size_t i = 0; i < x->n_args; i++ ) {
46+
ASR::expr_t** current_expr_copy_ = current_expr;
47+
current_expr = &(x->m_args[i]);
48+
replace_expr(x->m_args[i]);
49+
ASR::call_arg_t arg0;
50+
arg0.m_value = *current_expr; // Use the converted arg
51+
new_args.push_back(al, arg0);
52+
current_expr = current_expr_copy_;
53+
}
5354
}
5455
// TODO: currently we always instantiate a new function.
5556
// Rather we should reuse the old instantiation if it has
5657
// exactly the same arguments. For that we could use the
5758
// overload_id, and uniquely encode the argument types.
5859
// We could maintain a mapping of type -> id and look it up.
59-
if( !ASRUtils::IntrinsicFunctionRegistry::is_intrinsic_function(x->m_intrinsic_id) ) {
60-
throw LCompilersException("Intrinsic function not implemented");
61-
}
6260

6361
ASRUtils::impl_function instantiate_function =
6462
ASRUtils::IntrinsicFunctionRegistry::get_instantiate_function(x->m_intrinsic_id);
63+
if( instantiate_function == nullptr ) {
64+
return ;
65+
}
6566
Vec<ASR::ttype_t*> arg_types;
6667
arg_types.reserve(al, x->n_args);
6768
for( size_t i = 0; i < x->n_args; i++ ) {

src/libasr/pass/intrinsic_function_registry.h

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ enum class IntrinsicFunctions : int64_t {
4343
Gamma,
4444
LogGamma,
4545
Abs,
46+
47+
ListIndex,
4648
// ...
4749
};
4850

@@ -453,6 +455,54 @@ namespace Abs {
453455

454456
} // namespace Abs
455457

458+
namespace ListIndex {
459+
460+
static inline ASR::expr_t *eval_list_index(Allocator &/*al*/,
461+
const Location &/*loc*/, Vec<ASR::expr_t*>& /*args*/) {
462+
// TODO: To be implemented for ListConstant expression
463+
return nullptr;
464+
}
465+
466+
static inline ASR::asr_t* create_ListIndex(Allocator& al, const Location& loc,
467+
Vec<ASR::expr_t*>& args,
468+
const std::function<void (const std::string &, const Location &)> err) {
469+
if (args.size() != 2) {
470+
// Support start and end arguments by overloading ListIndex
471+
// intrinsic. We need 3 overload IDs,
472+
// 0 - only list and element
473+
// 1 - list, element and start
474+
// 2 - list, element, start and end
475+
// list, element and end case is not possible as list.index
476+
// doesn't accept keyword arguments
477+
err("For now index() takes exactly one argument", loc);
478+
}
479+
480+
ASR::expr_t* list_expr = args[0];
481+
ASR::ttype_t *type = ASRUtils::expr_type(list_expr);
482+
ASR::ttype_t *list_type = ASR::down_cast<ASR::List_t>(type)->m_type;
483+
ASR::ttype_t *ele_type = ASRUtils::expr_type(args[1]);
484+
if (!ASRUtils::check_equal_type(ele_type, list_type)) {
485+
std::string fnd = ASRUtils::type_to_str_python(ele_type);
486+
std::string org = ASRUtils::type_to_str_python(list_type);
487+
err(
488+
"Type mismatch in 'index', the types must be compatible "
489+
"(found: '" + fnd + "', expected: '" + org + "')", loc);
490+
}
491+
Vec<ASR::expr_t*> arg_values;
492+
arg_values.reserve(al, args.size());
493+
for( size_t i = 0; i < args.size(); i++ ) {
494+
arg_values.push_back(al, ASRUtils::expr_value(args[i]));
495+
}
496+
ASR::expr_t* compile_time_value = eval_list_index(al, loc, arg_values);
497+
ASR::ttype_t *to_type = ASRUtils::TYPE(ASR::make_Integer_t(al, loc,
498+
4, nullptr, 0));
499+
return ASR::make_IntrinsicFunction_t(al, loc,
500+
static_cast<int64_t>(ASRUtils::IntrinsicFunctions::ListIndex),
501+
args.p, args.size(), 0, to_type, compile_time_value);
502+
}
503+
504+
} // namespace ListIndex
505+
456506

457507
namespace IntrinsicFunctionRegistry {
458508

@@ -468,13 +518,28 @@ namespace IntrinsicFunctionRegistry {
468518
&Abs::instantiate_Abs}
469519
};
470520

521+
static const std::map<int64_t, std::string>& intrinsic_function_id_to_name = {
522+
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::LogGamma),
523+
"log_gamma"},
524+
525+
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::Sin),
526+
"sin"},
527+
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::Cos),
528+
"cos"},
529+
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::Abs),
530+
"abs"},
531+
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::ListIndex),
532+
"list.index"}
533+
};
534+
471535
static const std::map<std::string,
472536
std::pair<create_intrinsic_function,
473537
eval_intrinsic_function>>& intrinsic_function_by_name_db = {
474538
{"log_gamma", {&LogGamma::create_LogGamma, &LogGamma::eval_log_gamma}},
475539
{"sin", {&Sin::create_Sin, &Sin::eval_Sin}},
476540
{"cos", {&Cos::create_Cos, &Cos::eval_Cos}},
477541
{"abs", {&Abs::create_Abs, &Abs::eval_Abs}},
542+
{"list.index", {&ListIndex::create_ListIndex, &ListIndex::eval_list_index}},
478543
};
479544

480545
static inline bool is_intrinsic_function(const std::string& name) {
@@ -490,9 +555,20 @@ namespace IntrinsicFunctionRegistry {
490555
}
491556

492557
static inline impl_function get_instantiate_function(int64_t id) {
558+
if( intrinsic_function_by_id_db.find(id) == intrinsic_function_by_id_db.end() ) {
559+
return nullptr;
560+
}
493561
return intrinsic_function_by_id_db.at(id);
494562
}
495563

564+
static inline std::string get_intrinsic_function_name(int64_t id) {
565+
if( intrinsic_function_id_to_name.find(id) == intrinsic_function_id_to_name.end() ) {
566+
throw LCompilersException("IntrinsicFunction with ID " + std::to_string(id) +
567+
" has no name registered for it");
568+
}
569+
return intrinsic_function_id_to_name.at(id);
570+
}
571+
496572
} // namespace IntrinsicFunctionRegistry
497573

498574
#define INTRINSIC_NAME_CASE(X) \

src/lpython/semantics/python_attribute_eval.h

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <libasr/string_utils.h>
77
#include <lpython/utils.h>
88
#include <lpython/semantics/semantic_exception.h>
9+
#include <libasr/pass/intrinsic_function_registry.h>
910

1011
namespace LCompilers::LPython {
1112

@@ -151,29 +152,17 @@ struct AttributeHandler {
151152
}
152153

153154
static ASR::asr_t* eval_list_index(ASR::expr_t *s, Allocator &al, const Location &loc,
154-
Vec<ASR::expr_t*> &args, diag::Diagnostics &diag) {
155-
if (args.size() != 1) {
156-
throw SemanticError("index() takes exactly one argument",
157-
loc);
158-
}
159-
ASR::ttype_t *type = ASRUtils::expr_type(s);
160-
ASR::ttype_t *list_type = ASR::down_cast<ASR::List_t>(type)->m_type;
161-
ASR::ttype_t *ele_type = ASRUtils::expr_type(args[0]);
162-
if (!ASRUtils::check_equal_type(ele_type, list_type)) {
163-
std::string fnd = ASRUtils::type_to_str_python(ele_type);
164-
std::string org = ASRUtils::type_to_str_python(list_type);
165-
diag.add(diag::Diagnostic(
166-
"Type mismatch in 'index', the types must be compatible",
167-
diag::Level::Error, diag::Stage::Semantic, {
168-
diag::Label("type mismatch (found: '" + fnd + "', expected: '" + org + "')",
169-
{args[0]->base.loc})
170-
})
171-
);
172-
throw SemanticAbort();
155+
Vec<ASR::expr_t*> &args, diag::Diagnostics &/*diag*/) {
156+
Vec<ASR::expr_t*> args_with_list;
157+
args_with_list.reserve(al, args.size() + 1);
158+
args_with_list.push_back(al, s);
159+
for(size_t i = 0; i < args.size(); i++) {
160+
args_with_list.push_back(al, args[i]);
173161
}
174-
ASR::ttype_t *to_type = ASRUtils::TYPE(ASR::make_Integer_t(al, loc,
175-
4, nullptr, 0));
176-
return make_ListIndex_t(al, loc, s, args[0], to_type, nullptr);
162+
ASRUtils::create_intrinsic_function create_function =
163+
ASRUtils::IntrinsicFunctionRegistry::get_create_function("list.index");
164+
return create_function(al, loc, args_with_list, [&](const std::string &msg, const Location &loc)
165+
{ throw SemanticError(msg, loc); });
177166
}
178167

179168
static ASR::asr_t* eval_list_insert(ASR::expr_t *s, Allocator &al, const Location &loc,

tests/errors/test_list_index.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,6 @@ def test_list_index_error():
44
a: list[i32]
55
a = [1, 2, 3]
66
# a.index(1.0) # type mismatch
7-
print(a.index(0)) # no error?
7+
print(a.index(0)) # no error?
8+
9+
test_list_index_error()
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{
2+
"basename": "runtime-test_list_index-0483808",
3+
"cmd": "lpython {infile}",
4+
"infile": "tests/errors/test_list_index.py",
5+
"infile_hash": "991dc5eddf2579b7c6b3bc31bb07656cec73bdf91c3196a761985682",
6+
"outfile": null,
7+
"outfile_hash": null,
8+
"stdout": null,
9+
"stdout_hash": null,
10+
"stderr": "runtime-test_list_index-0483808.stderr",
11+
"stderr_hash": "dd3d49b5f2f97ed8f1d27cd73ebca7a8740483660dd4ae702e2048b2",
12+
"returncode": 1
13+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
ValueError: The list does not contain the element: 0

tests/tests.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ asr = true
678678

679679
[[test]]
680680
filename = "errors/test_list_index.py"
681-
asr = true
681+
run = true
682682

683683
[[test]]
684684
filename = "errors/test_list1.py"

0 commit comments

Comments
 (0)