diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index b9e40ed31a..c0671c71c0 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -280,6 +280,7 @@ RUN(NAME test_list_08 LABELS cpython llvm c) RUN(NAME test_list_09 LABELS cpython llvm c) RUN(NAME test_list_10 LABELS cpython llvm c) RUN(NAME test_list_section LABELS cpython llvm c) +RUN(NAME test_list_count LABELS cpython llvm) RUN(NAME test_tuple_01 LABELS cpython llvm c) RUN(NAME test_tuple_02 LABELS cpython llvm c) RUN(NAME test_tuple_03 LABELS cpython llvm c) diff --git a/integration_tests/test_list_count.py b/integration_tests/test_list_count.py new file mode 100644 index 0000000000..98c64fd6cb --- /dev/null +++ b/integration_tests/test_list_count.py @@ -0,0 +1,55 @@ +from lpython import i32, f64 + +def test_list_count(): + i: i32 + x: list[i32] = [] + y: list[str] = [] + z: list[tuple[i32, str, f64]] = [] + + for i in range(-5, 0): + assert x.count(i) == 0 + x.append(i) + assert x.count(i) == 1 + x.append(i) + assert x.count(i) == 2 + x.remove(i) + assert x.count(i) == 1 + + assert x == [-5, -4, -3, -2, -1] + + for i in range(0, 5): + assert x.count(i) == 0 + x.append(i) + assert x.count(i) == 1 + + assert x == [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4] + + while len(x) > 0: + i = x[-1] + x.remove(i) + assert x.count(i) == 0 + + assert len(x) == 0 + assert x.count(0) == 0 + + # str + assert y.count('a') == 0 + y = ['a', 'abc', 'a', 'b'] + assert y.count('a') == 2 + y.append('a') + assert y.count('a') == 3 + y.remove('a') + assert y.count('a') == 2 + + # tuple, float + assert z.count((i32(-1), 'b', f64(2))) == 0 + z = [(i32(1), 'a', f64(2.01)), (i32(-1), 'b', f64(2)), (i32(1), 'a', f64(2.02))] + assert z.count((i32(1), 'a', f64(2.00))) == 0 + assert z.count((i32(1), 'a', f64(2.01))) == 1 + z.append((i32(1), 'a', f64(2))) + z.append((i32(1), 'a', f64(2.00))) + assert z.count((i32(1), 'a', f64(2))) == 2 + z.remove((i32(1), 'a', f64(2))) + assert z.count((i32(1), 'a', f64(2.00))) == 1 + +test_list_count() diff --git a/src/libasr/ASR.asdl b/src/libasr/ASR.asdl index cc060aca9d..34af5be3cd 100644 --- a/src/libasr/ASR.asdl +++ b/src/libasr/ASR.asdl @@ -254,6 +254,7 @@ expr | ListLen(expr arg, ttype type, expr? value) | ListConcat(expr left, expr right, ttype type, expr? value) | ListCompare(expr left, cmpop op, expr right, ttype type, expr? value) + | ListCount(expr arg, expr ele, ttype type, expr? value) | SetConstant(expr* elements, ttype type) | SetLen(expr arg, ttype type, expr? value) diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index 0e12c89555..ab54f53a78 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -1948,6 +1948,20 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor list_api->remove(plist, item, asr_el_type, *module); } + void visit_ListCount(const ASR::ListCount_t& x) { + ASR::ttype_t* asr_el_type = ASRUtils::get_contained_type(ASRUtils::expr_type(x.m_arg)); + int64_t ptr_loads_copy = ptr_loads; + ptr_loads = 0; + this->visit_expr(*x.m_arg); + llvm::Value* plist = tmp; + + ptr_loads = !LLVM::is_llvm_struct(asr_el_type); + this->visit_expr_wrapper(x.m_ele, true); + ptr_loads = ptr_loads_copy; + llvm::Value *item = tmp; + tmp = list_api->count(plist, item, asr_el_type, *module); + } + void visit_ListClear(const ASR::ListClear_t& x) { int64_t ptr_loads_copy = ptr_loads; ptr_loads = 0; diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index bc790d5974..a061e26f6a 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -2498,6 +2498,83 @@ namespace LCompilers { return LLVM::CreateLoad(*builder, i); } + llvm::Value* LLVMList::count(llvm::Value* list, llvm::Value* item, + ASR::ttype_t* item_type, llvm::Module& module) { + llvm::Type* pos_type = llvm::Type::getInt32Ty(context); + llvm::Value* current_end_point = LLVM::CreateLoad(*builder, + get_pointer_to_current_end_point(list)); + llvm::AllocaInst *i = builder->CreateAlloca(pos_type, nullptr); + LLVM::CreateStore(*builder, llvm::ConstantInt::get( + context, llvm::APInt(32, 0)), i); + llvm::AllocaInst *cnt = builder->CreateAlloca(pos_type, nullptr); + LLVM::CreateStore(*builder, llvm::ConstantInt::get( + context, llvm::APInt(32, 0)), cnt); + llvm::Value* tmp = nullptr; + + /* Equivalent in C++: + * int i = 0; + * int cnt = 0; + * while(end_point > i) { + * if(list[i] == item) { + * tmp = cnt+1; + * cnt = tmp; + * } + * tmp = i+1; + * i = tmp; + * } + */ + + llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); + llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); + llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); + + // head + llvm_utils->start_new_block(loophead); + { + llvm::Value *cond = builder->CreateICmpSGT(current_end_point, + LLVM::CreateLoad(*builder, i)); + builder->CreateCondBr(cond, loopbody, loopend); + } + + // body + llvm_utils->start_new_block(loopbody); + { + // if occurrence found, increment cnt + llvm::Function *fn = builder->GetInsertBlock()->getParent(); + llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); + llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); + llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); + + llvm::Value* left_arg = read_item(list, LLVM::CreateLoad(*builder, i), + false, module, LLVM::is_llvm_struct(item_type)); + llvm::Value* cond = llvm_utils->is_equal_by_value(left_arg, item, module, item_type); + builder->CreateCondBr(cond, thenBB, elseBB); + builder->SetInsertPoint(thenBB); + { + tmp = builder->CreateAdd( + LLVM::CreateLoad(*builder, cnt), + llvm::ConstantInt::get(context, llvm::APInt(32, 1))); + LLVM::CreateStore(*builder, tmp, cnt); + } + builder->CreateBr(mergeBB); + + llvm_utils->start_new_block(elseBB); + llvm_utils->start_new_block(mergeBB); + + // increment i + tmp = builder->CreateAdd( + LLVM::CreateLoad(*builder, i), + llvm::ConstantInt::get(context, llvm::APInt(32, 1))); + LLVM::CreateStore(*builder, tmp, i); + } + builder->CreateBr(loophead); + + // end + llvm_utils->start_new_block(loopend); + + return LLVM::CreateLoad(*builder, cnt); + } + void LLVMList::remove(llvm::Value* list, llvm::Value* item, ASR::ttype_t* item_type, llvm::Module& module) { llvm::Type* pos_type = llvm::Type::getInt32Ty(context); diff --git a/src/libasr/codegen/llvm_utils.h b/src/libasr/codegen/llvm_utils.h index 5a35bac52f..f684117d84 100644 --- a/src/libasr/codegen/llvm_utils.h +++ b/src/libasr/codegen/llvm_utils.h @@ -232,6 +232,9 @@ namespace LCompilers { llvm::Value* item, ASR::ttype_t* item_type, llvm::Module& module); + llvm::Value* count(llvm::Value* list, llvm::Value* item, + ASR::ttype_t* item_type, llvm::Module& module); + void free_data(llvm::Value* list, llvm::Module& module); llvm::Value* check_list_equality(llvm::Value* l1, llvm::Value* l2, ASR::ttype_t *item_type, diff --git a/src/lpython/semantics/python_attribute_eval.h b/src/lpython/semantics/python_attribute_eval.h index f16961e23a..a943ff0ca4 100644 --- a/src/lpython/semantics/python_attribute_eval.h +++ b/src/lpython/semantics/python_attribute_eval.h @@ -21,6 +21,7 @@ struct AttributeHandler { {"int@bit_length", &eval_int_bit_length}, {"list@append", &eval_list_append}, {"list@remove", &eval_list_remove}, + {"list@count", &eval_list_count}, {"list@clear", &eval_list_clear}, {"list@insert", &eval_list_insert}, {"list@pop", &eval_list_pop}, @@ -122,6 +123,32 @@ struct AttributeHandler { return make_ListRemove_t(al, loc, s, args[0]); } + static ASR::asr_t* eval_list_count(ASR::expr_t *s, Allocator &al, const Location &loc, + Vec &args, diag::Diagnostics &diag) { + if (args.size() != 1) { + throw SemanticError("count() takes exactly one argument", + loc); + } + ASR::ttype_t *type = ASRUtils::expr_type(s); + ASR::ttype_t *list_type = ASR::down_cast(type)->m_type; + ASR::ttype_t *ele_type = ASRUtils::expr_type(args[0]); + if (!ASRUtils::check_equal_type(ele_type, list_type)) { + std::string fnd = ASRUtils::type_to_str_python(ele_type); + std::string org = ASRUtils::type_to_str_python(list_type); + diag.add(diag::Diagnostic( + "Type mismatch in 'count', the types must be compatible", + diag::Level::Error, diag::Stage::Semantic, { + diag::Label("type mismatch (found: '" + fnd + "', expected: '" + org + "')", + {args[0]->base.loc}) + }) + ); + throw SemanticAbort(); + } + ASR::ttype_t *to_type = ASRUtils::TYPE(ASR::make_Integer_t(al, loc, + 4, nullptr, 0)); + return make_ListCount_t(al, loc, s, args[0], to_type, nullptr); + } + static ASR::asr_t* eval_list_insert(ASR::expr_t *s, Allocator &al, const Location &loc, Vec &args, diag::Diagnostics &diag) { if (args.size() != 2) { diff --git a/tests/errors/test_list_count.py b/tests/errors/test_list_count.py new file mode 100644 index 0000000000..90c0f50d2d --- /dev/null +++ b/tests/errors/test_list_count.py @@ -0,0 +1,6 @@ +from lpython import i32 + +def test_list_count_error(): + a: list[i32] + a = [1, 2, 3] + a.count(1.0) \ No newline at end of file diff --git a/tests/reference/asr-test_list_count-4b42498.json b/tests/reference/asr-test_list_count-4b42498.json new file mode 100644 index 0000000000..f4864b55fb --- /dev/null +++ b/tests/reference/asr-test_list_count-4b42498.json @@ -0,0 +1,13 @@ +{ + "basename": "asr-test_list_count-4b42498", + "cmd": "lpython --show-asr --no-color {infile} -o {outfile}", + "infile": "tests/errors/test_list_count.py", + "infile_hash": "01975bd7c4bba02fd811de536b218167da99b532fa955b7bf8339779", + "outfile": null, + "outfile_hash": null, + "stdout": null, + "stdout_hash": null, + "stderr": "asr-test_list_count-4b42498.stderr", + "stderr_hash": "f26efcc623b68ca43ef871eb01c8e3cbd1ae464baaa491c6e4969696", + "returncode": 2 +} \ No newline at end of file diff --git a/tests/reference/asr-test_list_count-4b42498.stderr b/tests/reference/asr-test_list_count-4b42498.stderr new file mode 100644 index 0000000000..ad60a50f0e --- /dev/null +++ b/tests/reference/asr-test_list_count-4b42498.stderr @@ -0,0 +1,5 @@ +semantic error: Type mismatch in 'count', the types must be compatible + --> tests/errors/test_list_count.py:6:13 + | +6 | a.count(1.0) + | ^^^ type mismatch (found: 'f64', expected: 'i32') diff --git a/tests/tests.toml b/tests/tests.toml index f32230290a..1755dcb5e8 100644 --- a/tests/tests.toml +++ b/tests/tests.toml @@ -676,6 +676,10 @@ asr = true filename = "errors/test_list_concat.py" asr = true +[[test]] +filename = "errors/test_list_count.py" +asr = true + [[test]] filename = "errors/test_list1.py" asr = true