Skip to content

Commit f6bcbd2

Browse files
Add list.index using IntrinsicFunction API (#1703)
Co-authored-by: Gagandeep Singh <[email protected]>
1 parent 623a401 commit f6bcbd2

14 files changed

+235
-11
lines changed

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,7 @@ RUN(NAME test_list_09 LABELS cpython llvm c)
281281
RUN(NAME test_list_10 LABELS cpython llvm c)
282282
RUN(NAME test_list_section LABELS cpython llvm c)
283283
RUN(NAME test_list_count LABELS cpython llvm)
284+
RUN(NAME test_list_index LABELS cpython llvm)
284285
RUN(NAME test_tuple_01 LABELS cpython llvm c)
285286
RUN(NAME test_tuple_02 LABELS cpython llvm c)
286287
RUN(NAME test_tuple_03 LABELS cpython llvm c)

integration_tests/test_list_index.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from lpython import i32, f64
2+
3+
def test_list_index():
4+
i: i32
5+
x: list[i32] = []
6+
y: list[str] = []
7+
z: list[tuple[i32, str, f64]] = []
8+
9+
for i in range(-5, 0):
10+
x.append(i)
11+
assert x.index(i) == len(x)-1
12+
x.append(i)
13+
assert x.index(i) == len(x)-2
14+
x.remove(i)
15+
assert x.index(i) == len(x)-1
16+
17+
assert x == [-5, -4, -3, -2, -1]
18+
19+
for i in range(-5, 0):
20+
x.append(i)
21+
assert x.index(i) == 0
22+
x.remove(i)
23+
assert x.index(i) == len(x)-1
24+
25+
# str
26+
y = ['a', 'abc', 'a', 'b', 'abc']
27+
assert y.index('a') == 0
28+
assert y.index('abc') == 1
29+
30+
# tuple, float
31+
z = [(i32(1), 'a', f64(2.01)), (i32(-1), 'b', f64(2)), (i32(1), 'a', f64(2.02))]
32+
assert z.index((i32(1), 'a', f64(2.01))) == 0
33+
z.insert(0, (i32(1), 'a', f64(2)))
34+
assert z.index((i32(1), 'a', f64(2.00))) == 0
35+
z.append((i32(1), 'a', f64(2.00)))
36+
assert z.index((i32(1), 'a', f64(2))) == 0
37+
z.remove((i32(1), 'a', f64(2)))
38+
assert z.index((i32(1), 'a', f64(2.00))) == 3
39+
40+
test_list_index()

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@
4848
#include <libasr/codegen/llvm_utils.h>
4949
#include <libasr/codegen/llvm_array_utils.h>
5050

51+
#include <libasr/pass/intrinsic_function_registry.h>
52+
5153
#if LLVM_VERSION_MAJOR >= 11
5254
# define FIXED_VECTOR_TYPE llvm::FixedVectorType
5355
#else
@@ -1929,6 +1931,45 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
19291931
tmp = list_api->count(plist, item, asr_el_type, *module);
19301932
}
19311933

1934+
void generate_ListIndex(ASR::expr_t* m_arg, ASR::expr_t* m_ele) {
1935+
ASR::ttype_t* asr_el_type = ASRUtils::get_contained_type(ASRUtils::expr_type(m_arg));
1936+
int64_t ptr_loads_copy = ptr_loads;
1937+
ptr_loads = 0;
1938+
this->visit_expr(*m_arg);
1939+
llvm::Value* plist = tmp;
1940+
1941+
ptr_loads = !LLVM::is_llvm_struct(asr_el_type);
1942+
this->visit_expr_wrapper(m_ele, true);
1943+
ptr_loads = ptr_loads_copy;
1944+
llvm::Value *item = tmp;
1945+
tmp = list_api->index(plist, item, asr_el_type, *module);
1946+
}
1947+
1948+
void visit_IntrinsicFunction(const ASR::IntrinsicFunction_t& x) {
1949+
switch (static_cast<ASRUtils::IntrinsicFunctions>(x.m_intrinsic_id)) {
1950+
case ASRUtils::IntrinsicFunctions::ListIndex: {
1951+
switch (x.m_overload_id) {
1952+
case 0: {
1953+
ASR::expr_t* m_arg = x.m_args[0];
1954+
ASR::expr_t* m_ele = x.m_args[1];
1955+
generate_ListIndex(m_arg, m_ele);
1956+
break ;
1957+
}
1958+
default: {
1959+
throw CodeGenError("list.index only accepts one argument",
1960+
x.base.base.loc);
1961+
}
1962+
}
1963+
break ;
1964+
}
1965+
default: {
1966+
throw CodeGenError( ASRUtils::IntrinsicFunctionRegistry::
1967+
get_intrinsic_function_name(x.m_intrinsic_id) +
1968+
" is not implemented by LLVM backend.", x.base.base.loc);
1969+
}
1970+
}
1971+
}
1972+
19321973
void visit_ListClear(const ASR::ListClear_t& x) {
19331974
int64_t ptr_loads_copy = ptr_loads;
19341975
ptr_loads = 0;

src/libasr/codegen/llvm_utils.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2498,6 +2498,11 @@ namespace LCompilers {
24982498
return LLVM::CreateLoad(*builder, i);
24992499
}
25002500

2501+
llvm::Value* LLVMList::index(llvm::Value* list, llvm::Value* item,
2502+
ASR::ttype_t* item_type, llvm::Module& module) {
2503+
return LLVMList::find_item_position(list, item, item_type, module);
2504+
}
2505+
25012506
llvm::Value* LLVMList::count(llvm::Value* list, llvm::Value* item,
25022507
ASR::ttype_t* item_type, llvm::Module& module) {
25032508
llvm::Type* pos_type = llvm::Type::getInt32Ty(context);

src/libasr/codegen/llvm_utils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,9 @@ namespace LCompilers {
233233
llvm::Value* item, ASR::ttype_t* item_type,
234234
llvm::Module& module);
235235

236+
llvm::Value* index(llvm::Value* list, llvm::Value* item,
237+
ASR::ttype_t* item_type, llvm::Module& module);
238+
236239
llvm::Value* count(llvm::Value* list, llvm::Value* item,
237240
ASR::ttype_t* item_type, llvm::Module& module);
238241

src/libasr/pass/intrinsic_function.cpp

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,30 +38,31 @@ class ReplaceIntrinsicFunction: public ASR::BaseExprReplacer<ReplaceIntrinsicFun
3838

3939

4040
void replace_IntrinsicFunction(ASR::IntrinsicFunction_t* x) {
41-
LCOMPILERS_ASSERT(x->n_args == 1)
4241
Vec<ASR::call_arg_t> new_args;
4342
// Replace any IntrinsicFunctions in the argument first:
4443
{
45-
ASR::expr_t** current_expr_copy_ = current_expr;
46-
current_expr = &(x->m_args[0]);
47-
replace_expr(x->m_args[0]);
4844
new_args.reserve(al, x->n_args);
49-
ASR::call_arg_t arg0;
50-
arg0.m_value = *current_expr; // Use the converted arg
51-
new_args.push_back(al, arg0);
52-
current_expr = current_expr_copy_;
45+
for( size_t i = 0; i < x->n_args; i++ ) {
46+
ASR::expr_t** current_expr_copy_ = current_expr;
47+
current_expr = &(x->m_args[i]);
48+
replace_expr(x->m_args[i]);
49+
ASR::call_arg_t arg0;
50+
arg0.m_value = *current_expr; // Use the converted arg
51+
new_args.push_back(al, arg0);
52+
current_expr = current_expr_copy_;
53+
}
5354
}
5455
// TODO: currently we always instantiate a new function.
5556
// Rather we should reuse the old instantiation if it has
5657
// exactly the same arguments. For that we could use the
5758
// overload_id, and uniquely encode the argument types.
5859
// We could maintain a mapping of type -> id and look it up.
59-
if( !ASRUtils::IntrinsicFunctionRegistry::is_intrinsic_function(x->m_intrinsic_id) ) {
60-
throw LCompilersException("Intrinsic function not implemented");
61-
}
6260

6361
ASRUtils::impl_function instantiate_function =
6462
ASRUtils::IntrinsicFunctionRegistry::get_instantiate_function(x->m_intrinsic_id);
63+
if( instantiate_function == nullptr ) {
64+
return ;
65+
}
6566
Vec<ASR::ttype_t*> arg_types;
6667
arg_types.reserve(al, x->n_args);
6768
for( size_t i = 0; i < x->n_args; i++ ) {

src/libasr/pass/intrinsic_function_registry.h

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ enum class IntrinsicFunctions : int64_t {
4343
Gamma,
4444
LogGamma,
4545
Abs,
46+
47+
ListIndex,
4648
// ...
4749
};
4850

@@ -453,6 +455,54 @@ namespace Abs {
453455

454456
} // namespace Abs
455457

458+
namespace ListIndex {
459+
460+
static inline ASR::expr_t *eval_list_index(Allocator &/*al*/,
461+
const Location &/*loc*/, Vec<ASR::expr_t*>& /*args*/) {
462+
// TODO: To be implemented for ListConstant expression
463+
return nullptr;
464+
}
465+
466+
static inline ASR::asr_t* create_ListIndex(Allocator& al, const Location& loc,
467+
Vec<ASR::expr_t*>& args,
468+
const std::function<void (const std::string &, const Location &)> err) {
469+
if (args.size() != 2) {
470+
// Support start and end arguments by overloading ListIndex
471+
// intrinsic. We need 3 overload IDs,
472+
// 0 - only list and element
473+
// 1 - list, element and start
474+
// 2 - list, element, start and end
475+
// list, element and end case is not possible as list.index
476+
// doesn't accept keyword arguments
477+
err("For now index() takes exactly one argument", loc);
478+
}
479+
480+
ASR::expr_t* list_expr = args[0];
481+
ASR::ttype_t *type = ASRUtils::expr_type(list_expr);
482+
ASR::ttype_t *list_type = ASR::down_cast<ASR::List_t>(type)->m_type;
483+
ASR::ttype_t *ele_type = ASRUtils::expr_type(args[1]);
484+
if (!ASRUtils::check_equal_type(ele_type, list_type)) {
485+
std::string fnd = ASRUtils::get_type_code(ele_type);
486+
std::string org = ASRUtils::get_type_code(list_type);
487+
err(
488+
"Type mismatch in 'index', the types must be compatible "
489+
"(found: '" + fnd + "', expected: '" + org + "')", loc);
490+
}
491+
Vec<ASR::expr_t*> arg_values;
492+
arg_values.reserve(al, args.size());
493+
for( size_t i = 0; i < args.size(); i++ ) {
494+
arg_values.push_back(al, ASRUtils::expr_value(args[i]));
495+
}
496+
ASR::expr_t* compile_time_value = eval_list_index(al, loc, arg_values);
497+
ASR::ttype_t *to_type = ASRUtils::TYPE(ASR::make_Integer_t(al, loc,
498+
4, nullptr, 0));
499+
return ASR::make_IntrinsicFunction_t(al, loc,
500+
static_cast<int64_t>(ASRUtils::IntrinsicFunctions::ListIndex),
501+
args.p, args.size(), 0, to_type, compile_time_value);
502+
}
503+
504+
} // namespace ListIndex
505+
456506

457507
namespace IntrinsicFunctionRegistry {
458508

@@ -468,13 +518,28 @@ namespace IntrinsicFunctionRegistry {
468518
&Abs::instantiate_Abs}
469519
};
470520

521+
static const std::map<int64_t, std::string>& intrinsic_function_id_to_name = {
522+
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::LogGamma),
523+
"log_gamma"},
524+
525+
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::Sin),
526+
"sin"},
527+
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::Cos),
528+
"cos"},
529+
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::Abs),
530+
"abs"},
531+
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::ListIndex),
532+
"list.index"}
533+
};
534+
471535
static const std::map<std::string,
472536
std::pair<create_intrinsic_function,
473537
eval_intrinsic_function>>& intrinsic_function_by_name_db = {
474538
{"log_gamma", {&LogGamma::create_LogGamma, &LogGamma::eval_log_gamma}},
475539
{"sin", {&Sin::create_Sin, &Sin::eval_Sin}},
476540
{"cos", {&Cos::create_Cos, &Cos::eval_Cos}},
477541
{"abs", {&Abs::create_Abs, &Abs::eval_Abs}},
542+
{"list.index", {&ListIndex::create_ListIndex, &ListIndex::eval_list_index}},
478543
};
479544

480545
static inline bool is_intrinsic_function(const std::string& name) {
@@ -490,9 +555,20 @@ namespace IntrinsicFunctionRegistry {
490555
}
491556

492557
static inline impl_function get_instantiate_function(int64_t id) {
558+
if( intrinsic_function_by_id_db.find(id) == intrinsic_function_by_id_db.end() ) {
559+
return nullptr;
560+
}
493561
return intrinsic_function_by_id_db.at(id);
494562
}
495563

564+
static inline std::string get_intrinsic_function_name(int64_t id) {
565+
if( intrinsic_function_id_to_name.find(id) == intrinsic_function_id_to_name.end() ) {
566+
throw LCompilersException("IntrinsicFunction with ID " + std::to_string(id) +
567+
" has no name registered for it");
568+
}
569+
return intrinsic_function_id_to_name.at(id);
570+
}
571+
496572
} // namespace IntrinsicFunctionRegistry
497573

498574
#define INTRINSIC_NAME_CASE(X) \

src/lpython/semantics/python_attribute_eval.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <libasr/string_utils.h>
77
#include <lpython/utils.h>
88
#include <lpython/semantics/semantic_exception.h>
9+
#include <libasr/pass/intrinsic_function_registry.h>
910

1011
namespace LCompilers::LPython {
1112

@@ -22,6 +23,7 @@ struct AttributeHandler {
2223
{"list@append", &eval_list_append},
2324
{"list@remove", &eval_list_remove},
2425
{"list@count", &eval_list_count},
26+
{"list@index", &eval_list_index},
2527
{"list@clear", &eval_list_clear},
2628
{"list@insert", &eval_list_insert},
2729
{"list@pop", &eval_list_pop},
@@ -149,6 +151,20 @@ struct AttributeHandler {
149151
return make_ListCount_t(al, loc, s, args[0], to_type, nullptr);
150152
}
151153

154+
static ASR::asr_t* eval_list_index(ASR::expr_t *s, Allocator &al, const Location &loc,
155+
Vec<ASR::expr_t*> &args, diag::Diagnostics &/*diag*/) {
156+
Vec<ASR::expr_t*> args_with_list;
157+
args_with_list.reserve(al, args.size() + 1);
158+
args_with_list.push_back(al, s);
159+
for(size_t i = 0; i < args.size(); i++) {
160+
args_with_list.push_back(al, args[i]);
161+
}
162+
ASRUtils::create_intrinsic_function create_function =
163+
ASRUtils::IntrinsicFunctionRegistry::get_create_function("list.index");
164+
return create_function(al, loc, args_with_list, [&](const std::string &msg, const Location &loc)
165+
{ throw SemanticError(msg, loc); });
166+
}
167+
152168
static ASR::asr_t* eval_list_insert(ASR::expr_t *s, Allocator &al, const Location &loc,
153169
Vec<ASR::expr_t*> &args, diag::Diagnostics &diag) {
154170
if (args.size() != 2) {

tests/errors/test_list_index.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from lpython import i32
2+
3+
def test_list_index_error():
4+
a: list[i32]
5+
a = [1, 2, 3]
6+
# a.index(1.0) # type mismatch
7+
print(a.index(0)) # no error?
8+
9+
test_list_index_error()
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{
2+
"basename": "asr-test_list_index-6b8f30b",
3+
"cmd": "lpython --show-asr --no-color {infile} -o {outfile}",
4+
"infile": "tests/errors/test_list_index.py",
5+
"infile_hash": "9c629bc9805dde8f11fc465dc8b35f2c16aef28634c7f105b2e5cfa3",
6+
"outfile": null,
7+
"outfile_hash": null,
8+
"stdout": "asr-test_list_index-6b8f30b.stdout",
9+
"stdout_hash": "d91e269b3ba3053cbecf4e08960e027063d78db443c1bc6440cdf3c6",
10+
"stderr": null,
11+
"stderr_hash": null,
12+
"returncode": 0
13+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
(TranslationUnit (SymbolTable 1 {main_program: (Program (SymbolTable 3 {}) main_program [] []), test_list_index_error: (Function (SymbolTable 2 {a: (Variable 2 a [] Local () () Default (List (Integer 4 [])) Source Public Required .false.)}) test_list_index_error (FunctionType [] () Source Implementation () .false. .false. .false. .false. .false. [] [] .false.) [] [] [(= (Var 2 a) (ListConstant [(IntegerConstant 1 (Integer 4 [])) (IntegerConstant 2 (Integer 4 [])) (IntegerConstant 3 (Integer 4 []))] (List (Integer 4 []))) ()) (Print () [(ListIndex (Var 2 a) (IntegerConstant 0 (Integer 4 [])) (Integer 4 []) ())] () ())] () Public .false. .false.)}) [])
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{
2+
"basename": "runtime-test_list_index-0483808",
3+
"cmd": "lpython {infile}",
4+
"infile": "tests/errors/test_list_index.py",
5+
"infile_hash": "991dc5eddf2579b7c6b3bc31bb07656cec73bdf91c3196a761985682",
6+
"outfile": null,
7+
"outfile_hash": null,
8+
"stdout": null,
9+
"stdout_hash": null,
10+
"stderr": "runtime-test_list_index-0483808.stderr",
11+
"stderr_hash": "dd3d49b5f2f97ed8f1d27cd73ebca7a8740483660dd4ae702e2048b2",
12+
"returncode": 1
13+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
ValueError: The list does not contain the element: 0

tests/tests.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,10 @@ asr = true
676676
filename = "errors/test_list_count.py"
677677
asr = true
678678

679+
[[test]]
680+
filename = "errors/test_list_index.py"
681+
run = true
682+
679683
[[test]]
680684
filename = "errors/test_list1.py"
681685
asr = true

0 commit comments

Comments
 (0)