Skip to content

Commit 269a92d

Browse files
authored
Merge pull request #1359 from Smit-create/i-1245-2
[LLVM]: Implement Tuple/List Compare
2 parents 5e5380b + 892b085 commit 269a92d

File tree

4 files changed

+105
-4
lines changed

4 files changed

+105
-4
lines changed

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1589,6 +1589,21 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
15891589
}
15901590
}
15911591

1592+
void visit_ListCompare(const ASR::ListCompare_t x) {
1593+
int64_t ptr_loads_copy = ptr_loads;
1594+
ptr_loads = 0;
1595+
this->visit_expr(*x.m_left);
1596+
llvm::Value* left = tmp;
1597+
this->visit_expr(*x.m_right);
1598+
llvm::Value* right = tmp;
1599+
ptr_loads = ptr_loads_copy;
1600+
tmp = llvm_utils->is_equal_by_value(left, right, *module,
1601+
ASRUtils::expr_type(x.m_left));
1602+
if (x.m_op == ASR::cmpopType::NotEq) {
1603+
tmp = builder->CreateNot(tmp);
1604+
}
1605+
}
1606+
15921607
void visit_DictLen(const ASR::DictLen_t& x) {
15931608
if (x.m_value) {
15941609
this->visit_expr(*x.m_value);
@@ -1671,6 +1686,21 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
16711686
list_api->list_clear(plist);
16721687
}
16731688

1689+
void visit_TupleCompare(const ASR::TupleCompare_t& x) {
1690+
int64_t ptr_loads_copy = ptr_loads;
1691+
ptr_loads = 0;
1692+
this->visit_expr(*x.m_left);
1693+
llvm::Value* left = tmp;
1694+
this->visit_expr(*x.m_right);
1695+
llvm::Value* right = tmp;
1696+
ptr_loads = ptr_loads_copy;
1697+
tmp = llvm_utils->is_equal_by_value(left, right, *module,
1698+
ASRUtils::expr_type(x.m_left));
1699+
if (x.m_op == ASR::cmpopType::NotEq) {
1700+
tmp = builder->CreateNot(tmp);
1701+
}
1702+
}
1703+
16741704
void visit_TupleLen(const ASR::TupleLen_t& x) {
16751705
LFORTRAN_ASSERT(x.m_value);
16761706
this->visit_expr(*x.m_value);

src/libasr/codegen/llvm_utils.cpp

Lines changed: 72 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,10 @@ namespace LFortran {
236236
switch( asr_type->type ) {
237237
case ASR::ttypeType::Integer: {
238238
return builder->CreateICmpEQ(left, right);
239-
};
239+
}
240+
case ASR::ttypeType::Logical: {
241+
return builder->CreateICmpEQ(left, right);
242+
}
240243
case ASR::ttypeType::Real: {
241244
return builder->CreateFCmpOEQ(left, right);
242245
}
@@ -291,6 +294,11 @@ namespace LFortran {
291294
return tuple_api->check_tuple_equality(left, right, tuple_type, context,
292295
builder, module);
293296
}
297+
case ASR::ttypeType::List: {
298+
ASR::List_t* list_type = ASR::down_cast<ASR::List_t>(asr_type);
299+
return list_api->check_list_equality(left, right, list_type->m_type,
300+
context, builder, module);
301+
}
294302
default: {
295303
throw LCompilersException("LLVMUtils::is_equal_by_value isn't implemented for " +
296304
ASRUtils::type_to_str_python(asr_type));
@@ -2558,6 +2566,65 @@ namespace LFortran {
25582566
LLVM::lfortran_free(context, module, *builder, data);
25592567
}
25602568

2569+
llvm::Value* LLVMList::check_list_equality(llvm::Value* l1, llvm::Value* l2,
2570+
ASR::ttype_t* item_type,
2571+
llvm::LLVMContext& context,
2572+
llvm::IRBuilder<>* builder,
2573+
llvm::Module& module) {
2574+
llvm::AllocaInst *is_equal = builder->CreateAlloca(llvm::Type::getInt1Ty(context), nullptr);
2575+
LLVM::CreateStore(*builder, llvm::ConstantInt::get(context, llvm::APInt(1, 1)), is_equal);
2576+
llvm::Value *a_len = llvm_utils->list_api->len(l1);
2577+
llvm::Value *b_len = llvm_utils->list_api->len(l2);
2578+
llvm::Value *cond = builder->CreateICmpEQ(a_len, b_len);
2579+
llvm::Function *fn = builder->GetInsertBlock()->getParent();
2580+
llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn);
2581+
llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else");
2582+
llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont");
2583+
builder->CreateCondBr(cond, thenBB, elseBB);
2584+
builder->SetInsertPoint(thenBB);
2585+
llvm::AllocaInst *idx = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr);
2586+
LLVM::CreateStore(*builder, llvm::ConstantInt::get(
2587+
context, llvm::APInt(32, 0)), idx);
2588+
llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head");
2589+
llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body");
2590+
llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end");
2591+
2592+
// head
2593+
llvm_utils->start_new_block(loophead);
2594+
{
2595+
llvm::Value* i = LLVM::CreateLoad(*builder, idx);
2596+
llvm::Value* cnd = builder->CreateICmpSLT(i, a_len);
2597+
builder->CreateCondBr(cnd, loopbody, loopend);
2598+
}
2599+
2600+
// body
2601+
llvm_utils->start_new_block(loopbody);
2602+
{
2603+
llvm::Value* i = LLVM::CreateLoad(*builder, idx);
2604+
llvm::Value* left_arg = llvm_utils->list_api->read_item(l1, i,
2605+
false, module, LLVM::is_llvm_struct(item_type));
2606+
llvm::Value* right_arg = llvm_utils->list_api->read_item(l2, i,
2607+
false, module, LLVM::is_llvm_struct(item_type));
2608+
llvm::Value* res = llvm_utils->is_equal_by_value(left_arg, right_arg, module,
2609+
item_type);
2610+
res = builder->CreateAnd(LLVM::CreateLoad(*builder, is_equal), res);
2611+
LLVM::CreateStore(*builder, res, is_equal);
2612+
i = builder->CreateAdd(i, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context),
2613+
llvm::APInt(32, 1)));
2614+
LLVM::CreateStore(*builder, i, idx);
2615+
}
2616+
2617+
builder->CreateBr(loophead);
2618+
2619+
// end
2620+
llvm_utils->start_new_block(loopend);
2621+
2622+
builder->CreateBr(mergeBB);
2623+
llvm_utils->start_new_block(elseBB);
2624+
LLVM::CreateStore(*builder, llvm::ConstantInt::get(context, llvm::APInt(1, 0)), is_equal);
2625+
llvm_utils->start_new_block(mergeBB);
2626+
return LLVM::CreateLoad(*builder, is_equal);
2627+
}
25612628

25622629
LLVMTuple::LLVMTuple(llvm::LLVMContext& context_,
25632630
LLVMUtils* llvm_utils_,
@@ -2619,8 +2686,10 @@ namespace LFortran {
26192686
llvm::Module& module) {
26202687
llvm::Value* is_equal = llvm::ConstantInt::get(context, llvm::APInt(1, 1));
26212688
for( size_t i = 0; i < tuple_type->n_type; i++ ) {
2622-
llvm::Value* t1i = llvm_utils->tuple_api->read_item(t1, i);
2623-
llvm::Value* t2i = llvm_utils->tuple_api->read_item(t2, i);
2689+
llvm::Value* t1i = llvm_utils->tuple_api->read_item(t1, i, LLVM::is_llvm_struct(
2690+
tuple_type->m_type[i]));
2691+
llvm::Value* t2i = llvm_utils->tuple_api->read_item(t2, i, LLVM::is_llvm_struct(
2692+
tuple_type->m_type[i]));
26242693
llvm::Value* is_t1_eq_t2 = llvm_utils->is_equal_by_value(t1i, t2i, module,
26252694
tuple_type->m_type[i]);
26262695
is_equal = builder->CreateAnd(is_equal, is_t1_eq_t2);

src/libasr/codegen/llvm_utils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,9 @@ namespace LFortran {
212212
llvm::Module& module);
213213

214214
void free_data(llvm::Value* list, llvm::Module& module);
215+
216+
llvm::Value* check_list_equality(llvm::Value* l1, llvm::Value* l2, ASR::ttype_t *item_type,
217+
llvm::LLVMContext& context, llvm::IRBuilder<>* builder, llvm::Module& module);
215218
};
216219

217220
class LLVMTuple {

src/libasr/pass/pass_manager.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,6 @@ namespace LCompilers {
129129
"select_case",
130130
"inline_function_calls",
131131
"unused_functions",
132-
"pass_compare"
133132
};
134133

135134
_with_optimization_passes = {

0 commit comments

Comments
 (0)