diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index 3072540619..8250d26b38 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -966,24 +966,16 @@ namespace LCompilers { llvm_utils->create_ptr_gep(src_key_mask, itr)); LLVM::CreateStore(*builder, key_mask_value, llvm_utils->create_ptr_gep(dest_key_mask, itr)); - llvm::Function *fn = builder->GetInsertBlock()->getParent(); - llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); - llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); - llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); llvm::Value* is_key_set = builder->CreateICmpEQ(key_mask_value, llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))); - builder->CreateCondBr(is_key_set, thenBB, elseBB); - builder->SetInsertPoint(thenBB); - { + llvm_utils->create_if_else(is_key_set, [&]() { llvm::Value* srci = llvm_utils->create_ptr_gep(src_key_value_pairs, itr); llvm::Value* desti = llvm_utils->create_ptr_gep(dest_key_value_pairs, itr); deepcopy_key_value_pair_linked_list(srci, desti, dest_key_value_pairs, src_capacity, dict_type, module, name2memidx); - } - builder->CreateBr(mergeBB); - llvm_utils->start_new_block(elseBB); - llvm_utils->start_new_block(mergeBB); + }, [=]() { + }); llvm::Value* tmp = builder->CreateAdd( itr, llvm::ConstantInt::get(context, llvm::APInt(32, 1))); @@ -1004,17 +996,10 @@ namespace LCompilers { llvm::Value* zero = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0)); - llvm::Function *fn = builder->GetInsertBlock()->getParent(); - llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); - llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); - llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); - llvm::Value* cond = builder->CreateOr( builder->CreateICmpSGE(pos, end_point), builder->CreateICmpSLT(pos, zero)); - builder->CreateCondBr(cond, thenBB, elseBB); - builder->SetInsertPoint(thenBB); - { + llvm_utils->create_if_else(cond, [&]() { std::string index_error = "IndexError: %s%d%s%d\n", message1 = "List index is out of range. Index range is (0, ", message2 = "), but the given index is "; @@ -1029,11 +1014,8 @@ namespace LCompilers { llvm::Value *exit_code = llvm::ConstantInt::get(context, llvm::APInt(32, exit_code_int)); exit(context, module, *builder, exit_code); - } - builder->CreateBr(mergeBB); - - llvm_utils->start_new_block(elseBB); - llvm_utils->start_new_block(mergeBB); + }, [=]() { + }); } void LLVMList::write_item(llvm::Value* list, llvm::Value* pos, @@ -1181,26 +1163,17 @@ namespace LCompilers { llvm::Value* is_key_matching = llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), llvm::APInt(1, 0)); LLVM::CreateStore(*builder, is_key_matching, is_key_matching_var); - llvm::Function *fn = builder->GetInsertBlock()->getParent(); - llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); - llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); - llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); llvm::Value* compare_keys = builder->CreateAnd(is_key_set, builder->CreateNot(is_key_skip)); - builder->CreateCondBr(compare_keys, thenBB, elseBB); - builder->SetInsertPoint(thenBB); - { + llvm_utils->create_if_else(compare_keys, [&]() { llvm::Value* original_key = llvm_utils->list_api->read_item(key_list, pos, false, module, LLVM::is_llvm_struct(key_asr_type)); is_key_matching = llvm_utils->is_equal_by_value(key, original_key, module, key_asr_type); LLVM::CreateStore(*builder, is_key_matching, is_key_matching_var); - } - builder->CreateBr(mergeBB); - + }, [=]() { + }); - llvm_utils->start_new_block(elseBB); - llvm_utils->start_new_block(mergeBB); // TODO: Allow safe exit if pos becomes key_hash again. // Ideally should not happen as dict will be resized once // load factor touches a threshold (which will always be less than 1) @@ -1269,25 +1242,16 @@ namespace LCompilers { llvm::Value* is_key_matching = llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), llvm::APInt(1, 0)); LLVM::CreateStore(*builder, is_key_matching, is_key_matching_var); - llvm::Function *fn = builder->GetInsertBlock()->getParent(); - llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); - llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); - llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); llvm::Value* compare_keys = builder->CreateAnd(is_key_set, builder->CreateNot(is_key_skip)); - builder->CreateCondBr(compare_keys, thenBB, elseBB); - builder->SetInsertPoint(thenBB); - { + llvm_utils->create_if_else(compare_keys, [&]() { llvm::Value* original_key = llvm_utils->list_api->read_item(key_list, pos, false, module, LLVM::is_llvm_struct(key_asr_type)); is_key_matching = llvm_utils->is_equal_by_value(key, original_key, module, key_asr_type); LLVM::CreateStore(*builder, is_key_matching, is_key_matching_var); - } - builder->CreateBr(mergeBB); - - llvm_utils->start_new_block(elseBB); - llvm_utils->start_new_block(mergeBB); + }, [=]() { + }); // TODO: Allow safe exit if pos becomes key_hash again. // Ideally should not happen as dict will be resized once // load factor touches a threshold (which will always be less than 1) @@ -1371,20 +1335,11 @@ namespace LCompilers { llvm::Value* break_signal = llvm_utils->is_equal_by_value(key, kv_key, module, key_asr_type); break_signal = builder->CreateNot(break_signal); LLVM::CreateStore(*builder, break_signal, is_key_matching_var); - llvm::Function *fn = builder->GetInsertBlock()->getParent(); - llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); - llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); - llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); - builder->CreateCondBr(break_signal, thenBB, elseBB); - builder->SetInsertPoint(thenBB); - { + llvm_utils->create_if_else(break_signal, [&]() { llvm::Value* next_kv_struct = LLVM::CreateLoad(*builder, llvm_utils->create_gep(kv_struct, 2)); LLVM::CreateStore(*builder, next_kv_struct, chain_itr); - } - builder->CreateBr(mergeBB); - - llvm_utils->start_new_block(elseBB); - llvm_utils->start_new_block(mergeBB); + }, [=]() { + }); } builder->CreateBr(loophead); @@ -2201,22 +2156,14 @@ namespace LCompilers { llvm::Value* itr = LLVM::CreateLoad(*builder, idx_ptr); llvm::Value* key_mask_value = LLVM::CreateLoad(*builder, llvm_utils->create_ptr_gep(old_key_mask_value, itr)); - llvm::Function *fn = builder->GetInsertBlock()->getParent(); - llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); - llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); - llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); llvm::Value* is_key_set = builder->CreateICmpEQ(key_mask_value, llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))); - builder->CreateCondBr(is_key_set, thenBB, elseBB); - builder->SetInsertPoint(thenBB); - { + llvm_utils->create_if_else(is_key_set, [&]() { llvm::Value* srci = llvm_utils->create_ptr_gep(old_key_value_pairs_value, itr); write_key_value_pair_linked_list(srci, dict, capacity, key_asr_type, value_asr_type, module, name2memidx); - } - builder->CreateBr(mergeBB); - llvm_utils->start_new_block(elseBB); - llvm_utils->start_new_block(mergeBB); + }, [=]() { + }); llvm::Value* tmp = builder->CreateAdd( itr, llvm::ConstantInt::get(context, llvm::APInt(32, 1))); @@ -2260,11 +2207,6 @@ namespace LCompilers { void LLVMDict::rehash_all_at_once_if_needed(llvm::Value* dict, llvm::Module* module, ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, std::map>& name2memidx) { - llvm::Function *fn = builder->GetInsertBlock()->getParent(); - llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); - llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); - llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); - llvm::Value* occupancy = LLVM::CreateLoad(*builder, get_pointer_to_occupancy(dict)); llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); llvm::Value* rehash_condition = builder->CreateICmpEQ(capacity, @@ -2278,26 +2220,16 @@ namespace LCompilers { llvm::Value* load_factor_threshold = llvm::ConstantFP::get(llvm::Type::getFloatTy(context), llvm::APFloat((float) 0.6)); rehash_condition = builder->CreateOr(rehash_condition, builder->CreateFCmpOGE(load_factor, load_factor_threshold)); - builder->CreateCondBr(rehash_condition, thenBB, elseBB); - builder->SetInsertPoint(thenBB); - { + llvm_utils->create_if_else(rehash_condition, [&]() { rehash(dict, module, key_asr_type, value_asr_type, name2memidx); - } - builder->CreateBr(mergeBB); - - llvm_utils->start_new_block(elseBB); - llvm_utils->start_new_block(mergeBB); + }, [=]() { + }); } void LLVMDictSeparateChaining::rehash_all_at_once_if_needed( llvm::Value* dict, llvm::Module* module, ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, std::map>& name2memidx) { - llvm::Function *fn = builder->GetInsertBlock()->getParent(); - llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); - llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); - llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); - llvm::Value* occupancy = LLVM::CreateLoad(*builder, get_pointer_to_occupancy(dict)); llvm::Value* buckets_filled = LLVM::CreateLoad(*builder, get_pointer_to_number_of_filled_buckets(dict)); llvm::Value* rehash_condition = LLVM::CreateLoad(*builder, get_pointer_to_rehash_flag(dict)); @@ -2310,15 +2242,10 @@ namespace LCompilers { llvm::APFloat((float) 2.0)); rehash_condition = builder->CreateAnd(rehash_condition, builder->CreateFCmpOGE(avg_ll_length, avg_ll_length_threshold)); - builder->CreateCondBr(rehash_condition, thenBB, elseBB); - builder->SetInsertPoint(thenBB); - { + llvm_utils->create_if_else(rehash_condition, [&]() { rehash(dict, module, key_asr_type, value_asr_type, name2memidx); - } - builder->CreateBr(mergeBB); - - llvm_utils->start_new_block(elseBB); - llvm_utils->start_new_block(mergeBB); + }, [=]() { + }); } void LLVMDict::write_item(llvm::Value* dict, llvm::Value* key, @@ -2756,17 +2683,9 @@ namespace LCompilers { // end llvm_utils->start_new_block(loopend); - - llvm::Function *fn = builder->GetInsertBlock()->getParent(); - llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); - llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); - llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); - llvm::Value* cond = builder->CreateICmpEQ( LLVM::CreateLoad(*builder, i), current_end_point); - builder->CreateCondBr(cond, thenBB, elseBB); - builder->SetInsertPoint(thenBB); - { + llvm_utils->create_if_else(cond, [&]() { std::string message = "The list does not contain the element: "; llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("ValueError: %s%d\n"); llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message); @@ -2775,12 +2694,8 @@ namespace LCompilers { llvm::Value *exit_code = llvm::ConstantInt::get(context, llvm::APInt(32, exit_code_int)); exit(context, module, *builder, exit_code); - } - builder->CreateBr(mergeBB); - - llvm_utils->start_new_block(elseBB); - llvm_utils->start_new_block(mergeBB); - + }, [=]() { + }); return LLVM::CreateLoad(*builder, i); } @@ -2831,27 +2746,16 @@ namespace LCompilers { llvm_utils->start_new_block(loopbody); { // if occurrence found, increment cnt - llvm::Function *fn = builder->GetInsertBlock()->getParent(); - llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); - llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); - llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); - llvm::Value* left_arg = read_item(list, LLVM::CreateLoad(*builder, i), false, module, LLVM::is_llvm_struct(item_type)); llvm::Value* cond = llvm_utils->is_equal_by_value(left_arg, item, module, item_type); - builder->CreateCondBr(cond, thenBB, elseBB); - builder->SetInsertPoint(thenBB); - { + llvm_utils->create_if_else(cond, [&]() { tmp = builder->CreateAdd( LLVM::CreateLoad(*builder, cnt), llvm::ConstantInt::get(context, llvm::APInt(32, 1))); LLVM::CreateStore(*builder, tmp, cnt); - } - builder->CreateBr(mergeBB); - - llvm_utils->start_new_block(elseBB); - llvm_utils->start_new_block(mergeBB); - + }, [=]() { + }); // increment i tmp = builder->CreateAdd( LLVM::CreateLoad(*builder, i),