Skip to content

Commit d489a65

Browse files
Add list.index (WIP)
1 parent ae6c547 commit d489a65

11 files changed

+116
-0
lines changed

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,7 @@ RUN(NAME test_list_09 LABELS cpython llvm c)
281281
RUN(NAME test_list_10 LABELS cpython llvm c)
282282
RUN(NAME test_list_section LABELS cpython llvm c)
283283
RUN(NAME test_list_count LABELS cpython llvm)
284+
RUN(NAME test_list_index LABELS cpython llvm)
284285
RUN(NAME test_tuple_01 LABELS cpython llvm c)
285286
RUN(NAME test_tuple_02 LABELS cpython llvm c)
286287
RUN(NAME test_tuple_03 LABELS cpython llvm c)

integration_tests/test_list_index.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from lpython import i32, f64
2+
3+
def test_list_index():
4+
i: i32
5+
x: list[i32] = []
6+
y: list[str] = []
7+
z: list[tuple[i32, str, f64]] = []
8+
9+
for i in range(-5, 0):
10+
x.append(i)
11+
assert x.index(i) == len(x)-1
12+
x.append(i)
13+
assert x.index(i) == len(x)-2
14+
x.remove(i)
15+
assert x.index(i) == len(x)-1
16+
17+
assert x == [-5, -4, -3, -2, -1]
18+
19+
for i in range(-5, 0):
20+
x.append(i)
21+
assert x.index(i) == 0
22+
x.remove(i)
23+
assert x.index(i) == len(x)-1
24+
25+
# str
26+
y = ['a', 'abc', 'a', 'b', 'abc']
27+
assert y.index('a') == 0
28+
assert y.index('abc') == 1
29+
30+
# tuple, float
31+
z = [(i32(1), 'a', f64(2.01)), (i32(-1), 'b', f64(2)), (i32(1), 'a', f64(2.02))]
32+
assert z.index((i32(1), 'a', f64(2.01))) == 0
33+
z.insert(0, (i32(1), 'a', f64(2)))
34+
assert z.index((i32(1), 'a', f64(2.00))) == 0
35+
z.append((i32(1), 'a', f64(2.00)))
36+
assert z.index((i32(1), 'a', f64(2))) == 0
37+
z.remove((i32(1), 'a', f64(2)))
38+
assert z.index((i32(1), 'a', f64(2.00))) == 3
39+
40+
test_list_index()

src/libasr/ASR.asdl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ expr
255255
| ListConcat(expr left, expr right, ttype type, expr? value)
256256
| ListCompare(expr left, cmpop op, expr right, ttype type, expr? value)
257257
| ListCount(expr arg, expr ele, ttype type, expr? value)
258+
| ListIndex(expr arg, expr ele, ttype type, expr? value)
258259

259260
| SetConstant(expr* elements, ttype type)
260261
| SetLen(expr arg, ttype type, expr? value)

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1962,6 +1962,20 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
19621962
tmp = list_api->count(plist, item, asr_el_type, *module);
19631963
}
19641964

1965+
void visit_ListIndex(const ASR::ListIndex_t& x) {
1966+
ASR::ttype_t* asr_el_type = ASRUtils::get_contained_type(ASRUtils::expr_type(x.m_arg));
1967+
int64_t ptr_loads_copy = ptr_loads;
1968+
ptr_loads = 0;
1969+
this->visit_expr(*x.m_arg);
1970+
llvm::Value* plist = tmp;
1971+
1972+
ptr_loads = !LLVM::is_llvm_struct(asr_el_type);
1973+
this->visit_expr_wrapper(x.m_ele, true);
1974+
ptr_loads = ptr_loads_copy;
1975+
llvm::Value *item = tmp;
1976+
tmp = list_api->index(plist, item, asr_el_type, *module);
1977+
}
1978+
19651979
void visit_ListClear(const ASR::ListClear_t& x) {
19661980
int64_t ptr_loads_copy = ptr_loads;
19671981
ptr_loads = 0;

src/libasr/codegen/llvm_utils.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2498,6 +2498,11 @@ namespace LCompilers {
24982498
return LLVM::CreateLoad(*builder, i);
24992499
}
25002500

2501+
llvm::Value* LLVMList::index(llvm::Value* list, llvm::Value* item,
2502+
ASR::ttype_t* item_type, llvm::Module& module) {
2503+
return LLVMList::find_item_position(list, item, item_type, module);
2504+
}
2505+
25012506
llvm::Value* LLVMList::count(llvm::Value* list, llvm::Value* item,
25022507
ASR::ttype_t* item_type, llvm::Module& module) {
25032508
llvm::Type* pos_type = llvm::Type::getInt32Ty(context);

src/libasr/codegen/llvm_utils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,9 @@ namespace LCompilers {
233233
llvm::Value* item, ASR::ttype_t* item_type,
234234
llvm::Module& module);
235235

236+
llvm::Value* index(llvm::Value* list, llvm::Value* item,
237+
ASR::ttype_t* item_type, llvm::Module& module);
238+
236239
llvm::Value* count(llvm::Value* list, llvm::Value* item,
237240
ASR::ttype_t* item_type, llvm::Module& module);
238241

src/lpython/semantics/python_attribute_eval.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ struct AttributeHandler {
2222
{"list@append", &eval_list_append},
2323
{"list@remove", &eval_list_remove},
2424
{"list@count", &eval_list_count},
25+
{"list@index", &eval_list_index},
2526
{"list@clear", &eval_list_clear},
2627
{"list@insert", &eval_list_insert},
2728
{"list@pop", &eval_list_pop},
@@ -149,6 +150,32 @@ struct AttributeHandler {
149150
return make_ListCount_t(al, loc, s, args[0], to_type, nullptr);
150151
}
151152

153+
static ASR::asr_t* eval_list_index(ASR::expr_t *s, Allocator &al, const Location &loc,
154+
Vec<ASR::expr_t*> &args, diag::Diagnostics &diag) {
155+
if (args.size() != 1) {
156+
throw SemanticError("index() takes exactly one argument",
157+
loc);
158+
}
159+
ASR::ttype_t *type = ASRUtils::expr_type(s);
160+
ASR::ttype_t *list_type = ASR::down_cast<ASR::List_t>(type)->m_type;
161+
ASR::ttype_t *ele_type = ASRUtils::expr_type(args[0]);
162+
if (!ASRUtils::check_equal_type(ele_type, list_type)) {
163+
std::string fnd = ASRUtils::type_to_str_python(ele_type);
164+
std::string org = ASRUtils::type_to_str_python(list_type);
165+
diag.add(diag::Diagnostic(
166+
"Type mismatch in 'index', the types must be compatible",
167+
diag::Level::Error, diag::Stage::Semantic, {
168+
diag::Label("type mismatch (found: '" + fnd + "', expected: '" + org + "')",
169+
{args[0]->base.loc})
170+
})
171+
);
172+
throw SemanticAbort();
173+
}
174+
ASR::ttype_t *to_type = ASRUtils::TYPE(ASR::make_Integer_t(al, loc,
175+
4, nullptr, 0));
176+
return make_ListIndex_t(al, loc, s, args[0], to_type, nullptr);
177+
}
178+
152179
static ASR::asr_t* eval_list_insert(ASR::expr_t *s, Allocator &al, const Location &loc,
153180
Vec<ASR::expr_t*> &args, diag::Diagnostics &diag) {
154181
if (args.size() != 2) {

tests/errors/test_list_index.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from lpython import i32
2+
3+
def test_list_index_error():
4+
a: list[i32]
5+
a = [1, 2, 3]
6+
# a.index(1.0) # type mismatch
7+
print(a.index(0)) # no error?
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{
2+
"basename": "asr-test_list_index-6b8f30b",
3+
"cmd": "lpython --show-asr --no-color {infile} -o {outfile}",
4+
"infile": "tests/errors/test_list_index.py",
5+
"infile_hash": "5519aabe3c897ad8601470e2ef732f194784dfb17d0eb1833a14c9e0",
6+
"outfile": null,
7+
"outfile_hash": null,
8+
"stdout": "asr-test_list_index-6b8f30b.stdout",
9+
"stdout_hash": "83ff2357edbe8f33401839aa12f16c2288fa4a242b24130b0b0536f6",
10+
"stderr": null,
11+
"stderr_hash": null,
12+
"returncode": 0
13+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
(TranslationUnit (SymbolTable 1 {main_program: (Program (SymbolTable 3 {}) main_program [] []), test_list_index_error: (Function (SymbolTable 2 {a: (Variable 2 a [] Local () () Default (List (Integer 4 [])) Source Public Required .false.), b: (Variable 2 b [] Local () () Default (Integer 4 []) Source Public Required .false.)}) test_list_index_error (FunctionType [] () Source Implementation () .false. .false. .false. .false. .false. [] [] .false.) [] [] [(= (Var 2 a) (ListConstant [(IntegerConstant 1 (Integer 4 [])) (IntegerConstant 2 (Integer 4 [])) (IntegerConstant 3 (Integer 4 []))] (List (Integer 4 []))) ()) (= (Var 2 b) (ListIndex (Var 2 a) (IntegerConstant 0 (Integer 4 [])) (Integer 4 []) ()) ())] () Public .false. .false.)}) [])

tests/tests.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,10 @@ asr = true
676676
filename = "errors/test_list_count.py"
677677
asr = true
678678

679+
[[test]]
680+
filename = "errors/test_list_index.py"
681+
asr = true
682+
679683
[[test]]
680684
filename = "errors/test_list1.py"
681685
asr = true

0 commit comments

Comments
 (0)