diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index a8b92b29cd..9cee6a76ad 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -1589,6 +1589,21 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } } + void visit_ListCompare(const ASR::ListCompare_t x) { + int64_t ptr_loads_copy = ptr_loads; + ptr_loads = 0; + this->visit_expr(*x.m_left); + llvm::Value* left = tmp; + this->visit_expr(*x.m_right); + llvm::Value* right = tmp; + ptr_loads = ptr_loads_copy; + tmp = llvm_utils->is_equal_by_value(left, right, *module, + ASRUtils::expr_type(x.m_left)); + if (x.m_op == ASR::cmpopType::NotEq) { + tmp = builder->CreateNot(tmp); + } + } + void visit_DictLen(const ASR::DictLen_t& x) { if (x.m_value) { this->visit_expr(*x.m_value); @@ -1671,6 +1686,21 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor list_api->list_clear(plist); } + void visit_TupleCompare(const ASR::TupleCompare_t& x) { + int64_t ptr_loads_copy = ptr_loads; + ptr_loads = 0; + this->visit_expr(*x.m_left); + llvm::Value* left = tmp; + this->visit_expr(*x.m_right); + llvm::Value* right = tmp; + ptr_loads = ptr_loads_copy; + tmp = llvm_utils->is_equal_by_value(left, right, *module, + ASRUtils::expr_type(x.m_left)); + if (x.m_op == ASR::cmpopType::NotEq) { + tmp = builder->CreateNot(tmp); + } + } + void visit_TupleLen(const ASR::TupleLen_t& x) { LFORTRAN_ASSERT(x.m_value); this->visit_expr(*x.m_value); diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index 9a09774074..5d9e0f71d4 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -236,7 +236,10 @@ namespace LFortran { switch( asr_type->type ) { case ASR::ttypeType::Integer: { return builder->CreateICmpEQ(left, right); - }; + } + case ASR::ttypeType::Logical: { + return builder->CreateICmpEQ(left, right); + } case ASR::ttypeType::Real: { return builder->CreateFCmpOEQ(left, right); } @@ -291,6 +294,11 @@ namespace LFortran { return tuple_api->check_tuple_equality(left, right, tuple_type, context, builder, module); } + case ASR::ttypeType::List: { + ASR::List_t* list_type = ASR::down_cast(asr_type); + return list_api->check_list_equality(left, right, list_type->m_type, + context, builder, module); + } default: { throw LCompilersException("LLVMUtils::is_equal_by_value isn't implemented for " + ASRUtils::type_to_str_python(asr_type)); @@ -2558,6 +2566,65 @@ namespace LFortran { LLVM::lfortran_free(context, module, *builder, data); } + llvm::Value* LLVMList::check_list_equality(llvm::Value* l1, llvm::Value* l2, + ASR::ttype_t* item_type, + llvm::LLVMContext& context, + llvm::IRBuilder<>* builder, + llvm::Module& module) { + llvm::AllocaInst *is_equal = builder->CreateAlloca(llvm::Type::getInt1Ty(context), nullptr); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(context, llvm::APInt(1, 1)), is_equal); + llvm::Value *a_len = llvm_utils->list_api->len(l1); + llvm::Value *b_len = llvm_utils->list_api->len(l2); + llvm::Value *cond = builder->CreateICmpEQ(a_len, b_len); + llvm::Function *fn = builder->GetInsertBlock()->getParent(); + llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); + llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); + llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); + builder->CreateCondBr(cond, thenBB, elseBB); + builder->SetInsertPoint(thenBB); + llvm::AllocaInst *idx = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + LLVM::CreateStore(*builder, llvm::ConstantInt::get( + context, llvm::APInt(32, 0)), idx); + llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); + llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); + llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); + + // head + llvm_utils->start_new_block(loophead); + { + llvm::Value* i = LLVM::CreateLoad(*builder, idx); + llvm::Value* cnd = builder->CreateICmpSLT(i, a_len); + builder->CreateCondBr(cnd, loopbody, loopend); + } + + // body + llvm_utils->start_new_block(loopbody); + { + llvm::Value* i = LLVM::CreateLoad(*builder, idx); + llvm::Value* left_arg = llvm_utils->list_api->read_item(l1, i, + false, module, LLVM::is_llvm_struct(item_type)); + llvm::Value* right_arg = llvm_utils->list_api->read_item(l2, i, + false, module, LLVM::is_llvm_struct(item_type)); + llvm::Value* res = llvm_utils->is_equal_by_value(left_arg, right_arg, module, + item_type); + res = builder->CreateAnd(LLVM::CreateLoad(*builder, is_equal), res); + LLVM::CreateStore(*builder, res, is_equal); + i = builder->CreateAdd(i, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 1))); + LLVM::CreateStore(*builder, i, idx); + } + + builder->CreateBr(loophead); + + // end + llvm_utils->start_new_block(loopend); + + builder->CreateBr(mergeBB); + llvm_utils->start_new_block(elseBB); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(context, llvm::APInt(1, 0)), is_equal); + llvm_utils->start_new_block(mergeBB); + return LLVM::CreateLoad(*builder, is_equal); + } LLVMTuple::LLVMTuple(llvm::LLVMContext& context_, LLVMUtils* llvm_utils_, @@ -2619,8 +2686,10 @@ namespace LFortran { llvm::Module& module) { llvm::Value* is_equal = llvm::ConstantInt::get(context, llvm::APInt(1, 1)); for( size_t i = 0; i < tuple_type->n_type; i++ ) { - llvm::Value* t1i = llvm_utils->tuple_api->read_item(t1, i); - llvm::Value* t2i = llvm_utils->tuple_api->read_item(t2, i); + llvm::Value* t1i = llvm_utils->tuple_api->read_item(t1, i, LLVM::is_llvm_struct( + tuple_type->m_type[i])); + llvm::Value* t2i = llvm_utils->tuple_api->read_item(t2, i, LLVM::is_llvm_struct( + tuple_type->m_type[i])); llvm::Value* is_t1_eq_t2 = llvm_utils->is_equal_by_value(t1i, t2i, module, tuple_type->m_type[i]); is_equal = builder->CreateAnd(is_equal, is_t1_eq_t2); diff --git a/src/libasr/codegen/llvm_utils.h b/src/libasr/codegen/llvm_utils.h index 2e89424971..1c87bbe950 100644 --- a/src/libasr/codegen/llvm_utils.h +++ b/src/libasr/codegen/llvm_utils.h @@ -212,6 +212,9 @@ namespace LFortran { llvm::Module& module); void free_data(llvm::Value* list, llvm::Module& module); + + llvm::Value* check_list_equality(llvm::Value* l1, llvm::Value* l2, ASR::ttype_t *item_type, + llvm::LLVMContext& context, llvm::IRBuilder<>* builder, llvm::Module& module); }; class LLVMTuple { diff --git a/src/libasr/pass/pass_manager.h b/src/libasr/pass/pass_manager.h index 6bddd46df6..e1ebf1e4ce 100644 --- a/src/libasr/pass/pass_manager.h +++ b/src/libasr/pass/pass_manager.h @@ -129,7 +129,6 @@ namespace LCompilers { "select_case", "inline_function_calls", "unused_functions", - "pass_compare" }; _with_optimization_passes = {