Skip to content

Commit 1efcc7f

Browse files
authored
Merge pull request #1676 from virendrakabra14/scratch
Add list.count
2 parents 7d08405 + ec62fdd commit 1efcc7f

11 files changed

+206
-0
lines changed

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ RUN(NAME test_list_08 LABELS cpython llvm c)
280280
RUN(NAME test_list_09 LABELS cpython llvm c)
281281
RUN(NAME test_list_10 LABELS cpython llvm c)
282282
RUN(NAME test_list_section LABELS cpython llvm c)
283+
RUN(NAME test_list_count LABELS cpython llvm)
283284
RUN(NAME test_tuple_01 LABELS cpython llvm c)
284285
RUN(NAME test_tuple_02 LABELS cpython llvm c)
285286
RUN(NAME test_tuple_03 LABELS cpython llvm c)

integration_tests/test_list_count.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from lpython import i32, f64
2+
3+
def test_list_count():
4+
i: i32
5+
x: list[i32] = []
6+
y: list[str] = []
7+
z: list[tuple[i32, str, f64]] = []
8+
9+
for i in range(-5, 0):
10+
assert x.count(i) == 0
11+
x.append(i)
12+
assert x.count(i) == 1
13+
x.append(i)
14+
assert x.count(i) == 2
15+
x.remove(i)
16+
assert x.count(i) == 1
17+
18+
assert x == [-5, -4, -3, -2, -1]
19+
20+
for i in range(0, 5):
21+
assert x.count(i) == 0
22+
x.append(i)
23+
assert x.count(i) == 1
24+
25+
assert x == [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
26+
27+
while len(x) > 0:
28+
i = x[-1]
29+
x.remove(i)
30+
assert x.count(i) == 0
31+
32+
assert len(x) == 0
33+
assert x.count(0) == 0
34+
35+
# str
36+
assert y.count('a') == 0
37+
y = ['a', 'abc', 'a', 'b']
38+
assert y.count('a') == 2
39+
y.append('a')
40+
assert y.count('a') == 3
41+
y.remove('a')
42+
assert y.count('a') == 2
43+
44+
# tuple, float
45+
assert z.count((i32(-1), 'b', f64(2))) == 0
46+
z = [(i32(1), 'a', f64(2.01)), (i32(-1), 'b', f64(2)), (i32(1), 'a', f64(2.02))]
47+
assert z.count((i32(1), 'a', f64(2.00))) == 0
48+
assert z.count((i32(1), 'a', f64(2.01))) == 1
49+
z.append((i32(1), 'a', f64(2)))
50+
z.append((i32(1), 'a', f64(2.00)))
51+
assert z.count((i32(1), 'a', f64(2))) == 2
52+
z.remove((i32(1), 'a', f64(2)))
53+
assert z.count((i32(1), 'a', f64(2.00))) == 1
54+
55+
test_list_count()

src/libasr/ASR.asdl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ expr
254254
| ListLen(expr arg, ttype type, expr? value)
255255
| ListConcat(expr left, expr right, ttype type, expr? value)
256256
| ListCompare(expr left, cmpop op, expr right, ttype type, expr? value)
257+
| ListCount(expr arg, expr ele, ttype type, expr? value)
257258

258259
| SetConstant(expr* elements, ttype type)
259260
| SetLen(expr arg, ttype type, expr? value)

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1948,6 +1948,20 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
19481948
list_api->remove(plist, item, asr_el_type, *module);
19491949
}
19501950

1951+
void visit_ListCount(const ASR::ListCount_t& x) {
1952+
ASR::ttype_t* asr_el_type = ASRUtils::get_contained_type(ASRUtils::expr_type(x.m_arg));
1953+
int64_t ptr_loads_copy = ptr_loads;
1954+
ptr_loads = 0;
1955+
this->visit_expr(*x.m_arg);
1956+
llvm::Value* plist = tmp;
1957+
1958+
ptr_loads = !LLVM::is_llvm_struct(asr_el_type);
1959+
this->visit_expr_wrapper(x.m_ele, true);
1960+
ptr_loads = ptr_loads_copy;
1961+
llvm::Value *item = tmp;
1962+
tmp = list_api->count(plist, item, asr_el_type, *module);
1963+
}
1964+
19511965
void visit_ListClear(const ASR::ListClear_t& x) {
19521966
int64_t ptr_loads_copy = ptr_loads;
19531967
ptr_loads = 0;

src/libasr/codegen/llvm_utils.cpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2498,6 +2498,83 @@ namespace LCompilers {
24982498
return LLVM::CreateLoad(*builder, i);
24992499
}
25002500

2501+
llvm::Value* LLVMList::count(llvm::Value* list, llvm::Value* item,
2502+
ASR::ttype_t* item_type, llvm::Module& module) {
2503+
llvm::Type* pos_type = llvm::Type::getInt32Ty(context);
2504+
llvm::Value* current_end_point = LLVM::CreateLoad(*builder,
2505+
get_pointer_to_current_end_point(list));
2506+
llvm::AllocaInst *i = builder->CreateAlloca(pos_type, nullptr);
2507+
LLVM::CreateStore(*builder, llvm::ConstantInt::get(
2508+
context, llvm::APInt(32, 0)), i);
2509+
llvm::AllocaInst *cnt = builder->CreateAlloca(pos_type, nullptr);
2510+
LLVM::CreateStore(*builder, llvm::ConstantInt::get(
2511+
context, llvm::APInt(32, 0)), cnt);
2512+
llvm::Value* tmp = nullptr;
2513+
2514+
/* Equivalent in C++:
2515+
* int i = 0;
2516+
* int cnt = 0;
2517+
* while(end_point > i) {
2518+
* if(list[i] == item) {
2519+
* tmp = cnt+1;
2520+
* cnt = tmp;
2521+
* }
2522+
* tmp = i+1;
2523+
* i = tmp;
2524+
* }
2525+
*/
2526+
2527+
llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head");
2528+
llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body");
2529+
llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end");
2530+
2531+
// head
2532+
llvm_utils->start_new_block(loophead);
2533+
{
2534+
llvm::Value *cond = builder->CreateICmpSGT(current_end_point,
2535+
LLVM::CreateLoad(*builder, i));
2536+
builder->CreateCondBr(cond, loopbody, loopend);
2537+
}
2538+
2539+
// body
2540+
llvm_utils->start_new_block(loopbody);
2541+
{
2542+
// if occurrence found, increment cnt
2543+
llvm::Function *fn = builder->GetInsertBlock()->getParent();
2544+
llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn);
2545+
llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else");
2546+
llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont");
2547+
2548+
llvm::Value* left_arg = read_item(list, LLVM::CreateLoad(*builder, i),
2549+
false, module, LLVM::is_llvm_struct(item_type));
2550+
llvm::Value* cond = llvm_utils->is_equal_by_value(left_arg, item, module, item_type);
2551+
builder->CreateCondBr(cond, thenBB, elseBB);
2552+
builder->SetInsertPoint(thenBB);
2553+
{
2554+
tmp = builder->CreateAdd(
2555+
LLVM::CreateLoad(*builder, cnt),
2556+
llvm::ConstantInt::get(context, llvm::APInt(32, 1)));
2557+
LLVM::CreateStore(*builder, tmp, cnt);
2558+
}
2559+
builder->CreateBr(mergeBB);
2560+
2561+
llvm_utils->start_new_block(elseBB);
2562+
llvm_utils->start_new_block(mergeBB);
2563+
2564+
// increment i
2565+
tmp = builder->CreateAdd(
2566+
LLVM::CreateLoad(*builder, i),
2567+
llvm::ConstantInt::get(context, llvm::APInt(32, 1)));
2568+
LLVM::CreateStore(*builder, tmp, i);
2569+
}
2570+
builder->CreateBr(loophead);
2571+
2572+
// end
2573+
llvm_utils->start_new_block(loopend);
2574+
2575+
return LLVM::CreateLoad(*builder, cnt);
2576+
}
2577+
25012578
void LLVMList::remove(llvm::Value* list, llvm::Value* item,
25022579
ASR::ttype_t* item_type, llvm::Module& module) {
25032580
llvm::Type* pos_type = llvm::Type::getInt32Ty(context);

src/libasr/codegen/llvm_utils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,9 @@ namespace LCompilers {
233233
llvm::Value* item, ASR::ttype_t* item_type,
234234
llvm::Module& module);
235235

236+
llvm::Value* count(llvm::Value* list, llvm::Value* item,
237+
ASR::ttype_t* item_type, llvm::Module& module);
238+
236239
void free_data(llvm::Value* list, llvm::Module& module);
237240

238241
llvm::Value* check_list_equality(llvm::Value* l1, llvm::Value* l2, ASR::ttype_t *item_type,

src/lpython/semantics/python_attribute_eval.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ struct AttributeHandler {
2121
{"int@bit_length", &eval_int_bit_length},
2222
{"list@append", &eval_list_append},
2323
{"list@remove", &eval_list_remove},
24+
{"list@count", &eval_list_count},
2425
{"list@clear", &eval_list_clear},
2526
{"list@insert", &eval_list_insert},
2627
{"list@pop", &eval_list_pop},
@@ -122,6 +123,32 @@ struct AttributeHandler {
122123
return make_ListRemove_t(al, loc, s, args[0]);
123124
}
124125

126+
static ASR::asr_t* eval_list_count(ASR::expr_t *s, Allocator &al, const Location &loc,
127+
Vec<ASR::expr_t*> &args, diag::Diagnostics &diag) {
128+
if (args.size() != 1) {
129+
throw SemanticError("count() takes exactly one argument",
130+
loc);
131+
}
132+
ASR::ttype_t *type = ASRUtils::expr_type(s);
133+
ASR::ttype_t *list_type = ASR::down_cast<ASR::List_t>(type)->m_type;
134+
ASR::ttype_t *ele_type = ASRUtils::expr_type(args[0]);
135+
if (!ASRUtils::check_equal_type(ele_type, list_type)) {
136+
std::string fnd = ASRUtils::type_to_str_python(ele_type);
137+
std::string org = ASRUtils::type_to_str_python(list_type);
138+
diag.add(diag::Diagnostic(
139+
"Type mismatch in 'count', the types must be compatible",
140+
diag::Level::Error, diag::Stage::Semantic, {
141+
diag::Label("type mismatch (found: '" + fnd + "', expected: '" + org + "')",
142+
{args[0]->base.loc})
143+
})
144+
);
145+
throw SemanticAbort();
146+
}
147+
ASR::ttype_t *to_type = ASRUtils::TYPE(ASR::make_Integer_t(al, loc,
148+
4, nullptr, 0));
149+
return make_ListCount_t(al, loc, s, args[0], to_type, nullptr);
150+
}
151+
125152
static ASR::asr_t* eval_list_insert(ASR::expr_t *s, Allocator &al, const Location &loc,
126153
Vec<ASR::expr_t*> &args, diag::Diagnostics &diag) {
127154
if (args.size() != 2) {

tests/errors/test_list_count.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from lpython import i32
2+
3+
def test_list_count_error():
4+
a: list[i32]
5+
a = [1, 2, 3]
6+
a.count(1.0)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{
2+
"basename": "asr-test_list_count-4b42498",
3+
"cmd": "lpython --show-asr --no-color {infile} -o {outfile}",
4+
"infile": "tests/errors/test_list_count.py",
5+
"infile_hash": "01975bd7c4bba02fd811de536b218167da99b532fa955b7bf8339779",
6+
"outfile": null,
7+
"outfile_hash": null,
8+
"stdout": null,
9+
"stdout_hash": null,
10+
"stderr": "asr-test_list_count-4b42498.stderr",
11+
"stderr_hash": "f26efcc623b68ca43ef871eb01c8e3cbd1ae464baaa491c6e4969696",
12+
"returncode": 2
13+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
semantic error: Type mismatch in 'count', the types must be compatible
2+
--> tests/errors/test_list_count.py:6:13
3+
|
4+
6 | a.count(1.0)
5+
| ^^^ type mismatch (found: 'f64', expected: 'i32')

tests/tests.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,10 @@ asr = true
676676
filename = "errors/test_list_concat.py"
677677
asr = true
678678

679+
[[test]]
680+
filename = "errors/test_list_count.py"
681+
asr = true
682+
679683
[[test]]
680684
filename = "errors/test_list1.py"
681685
asr = true

0 commit comments

Comments
 (0)