Skip to content

Commit fd6ee57

Browse files
authored
LLVM: Refactor to use cleaner create_if_else (lcompilers#1742)
1 parent ce6cae0 commit fd6ee57

File tree

1 file changed

+30
-126
lines changed

1 file changed

+30
-126
lines changed

src/libasr/codegen/llvm_utils.cpp

Lines changed: 30 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -966,24 +966,16 @@ namespace LCompilers {
966966
llvm_utils->create_ptr_gep(src_key_mask, itr));
967967
LLVM::CreateStore(*builder, key_mask_value,
968968
llvm_utils->create_ptr_gep(dest_key_mask, itr));
969-
llvm::Function *fn = builder->GetInsertBlock()->getParent();
970-
llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn);
971-
llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else");
972-
llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont");
973969
llvm::Value* is_key_set = builder->CreateICmpEQ(key_mask_value,
974970
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1)));
975-
builder->CreateCondBr(is_key_set, thenBB, elseBB);
976-
builder->SetInsertPoint(thenBB);
977-
{
978971

972+
llvm_utils->create_if_else(is_key_set, [&]() {
979973
llvm::Value* srci = llvm_utils->create_ptr_gep(src_key_value_pairs, itr);
980974
llvm::Value* desti = llvm_utils->create_ptr_gep(dest_key_value_pairs, itr);
981975
deepcopy_key_value_pair_linked_list(srci, desti, dest_key_value_pairs,
982976
src_capacity, dict_type, module, name2memidx);
983-
}
984-
builder->CreateBr(mergeBB);
985-
llvm_utils->start_new_block(elseBB);
986-
llvm_utils->start_new_block(mergeBB);
977+
}, [=]() {
978+
});
987979
llvm::Value* tmp = builder->CreateAdd(
988980
itr,
989981
llvm::ConstantInt::get(context, llvm::APInt(32, 1)));
@@ -1004,17 +996,10 @@ namespace LCompilers {
1004996
llvm::Value* zero = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context),
1005997
llvm::APInt(32, 0));
1006998

1007-
llvm::Function *fn = builder->GetInsertBlock()->getParent();
1008-
llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn);
1009-
llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else");
1010-
llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont");
1011-
1012999
llvm::Value* cond = builder->CreateOr(
10131000
builder->CreateICmpSGE(pos, end_point),
10141001
builder->CreateICmpSLT(pos, zero));
1015-
builder->CreateCondBr(cond, thenBB, elseBB);
1016-
builder->SetInsertPoint(thenBB);
1017-
{
1002+
llvm_utils->create_if_else(cond, [&]() {
10181003
std::string index_error = "IndexError: %s%d%s%d\n",
10191004
message1 = "List index is out of range. Index range is (0, ",
10201005
message2 = "), but the given index is ";
@@ -1029,11 +1014,8 @@ namespace LCompilers {
10291014
llvm::Value *exit_code = llvm::ConstantInt::get(context,
10301015
llvm::APInt(32, exit_code_int));
10311016
exit(context, module, *builder, exit_code);
1032-
}
1033-
builder->CreateBr(mergeBB);
1034-
1035-
llvm_utils->start_new_block(elseBB);
1036-
llvm_utils->start_new_block(mergeBB);
1017+
}, [=]() {
1018+
});
10371019
}
10381020

10391021
void LLVMList::write_item(llvm::Value* list, llvm::Value* pos,
@@ -1181,26 +1163,17 @@ namespace LCompilers {
11811163
llvm::Value* is_key_matching = llvm::ConstantInt::get(llvm::Type::getInt1Ty(context),
11821164
llvm::APInt(1, 0));
11831165
LLVM::CreateStore(*builder, is_key_matching, is_key_matching_var);
1184-
llvm::Function *fn = builder->GetInsertBlock()->getParent();
1185-
llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn);
1186-
llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else");
1187-
llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont");
11881166
llvm::Value* compare_keys = builder->CreateAnd(is_key_set,
11891167
builder->CreateNot(is_key_skip));
1190-
builder->CreateCondBr(compare_keys, thenBB, elseBB);
1191-
builder->SetInsertPoint(thenBB);
1192-
{
1168+
llvm_utils->create_if_else(compare_keys, [&]() {
11931169
llvm::Value* original_key = llvm_utils->list_api->read_item(key_list, pos,
11941170
false, module, LLVM::is_llvm_struct(key_asr_type));
11951171
is_key_matching = llvm_utils->is_equal_by_value(key, original_key, module,
11961172
key_asr_type);
11971173
LLVM::CreateStore(*builder, is_key_matching, is_key_matching_var);
1198-
}
1199-
builder->CreateBr(mergeBB);
1200-
1174+
}, [=]() {
1175+
});
12011176

1202-
llvm_utils->start_new_block(elseBB);
1203-
llvm_utils->start_new_block(mergeBB);
12041177
// TODO: Allow safe exit if pos becomes key_hash again.
12051178
// Ideally should not happen as dict will be resized once
12061179
// load factor touches a threshold (which will always be less than 1)
@@ -1269,25 +1242,16 @@ namespace LCompilers {
12691242
llvm::Value* is_key_matching = llvm::ConstantInt::get(llvm::Type::getInt1Ty(context),
12701243
llvm::APInt(1, 0));
12711244
LLVM::CreateStore(*builder, is_key_matching, is_key_matching_var);
1272-
llvm::Function *fn = builder->GetInsertBlock()->getParent();
1273-
llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn);
1274-
llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else");
1275-
llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont");
12761245
llvm::Value* compare_keys = builder->CreateAnd(is_key_set,
12771246
builder->CreateNot(is_key_skip));
1278-
builder->CreateCondBr(compare_keys, thenBB, elseBB);
1279-
builder->SetInsertPoint(thenBB);
1280-
{
1247+
llvm_utils->create_if_else(compare_keys, [&]() {
12811248
llvm::Value* original_key = llvm_utils->list_api->read_item(key_list, pos,
12821249
false, module, LLVM::is_llvm_struct(key_asr_type));
12831250
is_key_matching = llvm_utils->is_equal_by_value(key, original_key, module,
12841251
key_asr_type);
12851252
LLVM::CreateStore(*builder, is_key_matching, is_key_matching_var);
1286-
}
1287-
builder->CreateBr(mergeBB);
1288-
1289-
llvm_utils->start_new_block(elseBB);
1290-
llvm_utils->start_new_block(mergeBB);
1253+
}, [=]() {
1254+
});
12911255
// TODO: Allow safe exit if pos becomes key_hash again.
12921256
// Ideally should not happen as dict will be resized once
12931257
// load factor touches a threshold (which will always be less than 1)
@@ -1371,20 +1335,11 @@ namespace LCompilers {
13711335
llvm::Value* break_signal = llvm_utils->is_equal_by_value(key, kv_key, module, key_asr_type);
13721336
break_signal = builder->CreateNot(break_signal);
13731337
LLVM::CreateStore(*builder, break_signal, is_key_matching_var);
1374-
llvm::Function *fn = builder->GetInsertBlock()->getParent();
1375-
llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn);
1376-
llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else");
1377-
llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont");
1378-
builder->CreateCondBr(break_signal, thenBB, elseBB);
1379-
builder->SetInsertPoint(thenBB);
1380-
{
1338+
llvm_utils->create_if_else(break_signal, [&]() {
13811339
llvm::Value* next_kv_struct = LLVM::CreateLoad(*builder, llvm_utils->create_gep(kv_struct, 2));
13821340
LLVM::CreateStore(*builder, next_kv_struct, chain_itr);
1383-
}
1384-
builder->CreateBr(mergeBB);
1385-
1386-
llvm_utils->start_new_block(elseBB);
1387-
llvm_utils->start_new_block(mergeBB);
1341+
}, [=]() {
1342+
});
13881343
}
13891344

13901345
builder->CreateBr(loophead);
@@ -2201,22 +2156,14 @@ namespace LCompilers {
22012156
llvm::Value* itr = LLVM::CreateLoad(*builder, idx_ptr);
22022157
llvm::Value* key_mask_value = LLVM::CreateLoad(*builder,
22032158
llvm_utils->create_ptr_gep(old_key_mask_value, itr));
2204-
llvm::Function *fn = builder->GetInsertBlock()->getParent();
2205-
llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn);
2206-
llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else");
2207-
llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont");
22082159
llvm::Value* is_key_set = builder->CreateICmpEQ(key_mask_value,
22092160
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1)));
2210-
builder->CreateCondBr(is_key_set, thenBB, elseBB);
2211-
builder->SetInsertPoint(thenBB);
2212-
{
22132161

2162+
llvm_utils->create_if_else(is_key_set, [&]() {
22142163
llvm::Value* srci = llvm_utils->create_ptr_gep(old_key_value_pairs_value, itr);
22152164
write_key_value_pair_linked_list(srci, dict, capacity, key_asr_type, value_asr_type, module, name2memidx);
2216-
}
2217-
builder->CreateBr(mergeBB);
2218-
llvm_utils->start_new_block(elseBB);
2219-
llvm_utils->start_new_block(mergeBB);
2165+
}, [=]() {
2166+
});
22202167
llvm::Value* tmp = builder->CreateAdd(
22212168
itr,
22222169
llvm::ConstantInt::get(context, llvm::APInt(32, 1)));
@@ -2260,11 +2207,6 @@ namespace LCompilers {
22602207
void LLVMDict::rehash_all_at_once_if_needed(llvm::Value* dict, llvm::Module* module,
22612208
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type,
22622209
std::map<std::string, std::map<std::string, int>>& name2memidx) {
2263-
llvm::Function *fn = builder->GetInsertBlock()->getParent();
2264-
llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn);
2265-
llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else");
2266-
llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont");
2267-
22682210
llvm::Value* occupancy = LLVM::CreateLoad(*builder, get_pointer_to_occupancy(dict));
22692211
llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict));
22702212
llvm::Value* rehash_condition = builder->CreateICmpEQ(capacity,
@@ -2278,26 +2220,16 @@ namespace LCompilers {
22782220
llvm::Value* load_factor_threshold = llvm::ConstantFP::get(llvm::Type::getFloatTy(context),
22792221
llvm::APFloat((float) 0.6));
22802222
rehash_condition = builder->CreateOr(rehash_condition, builder->CreateFCmpOGE(load_factor, load_factor_threshold));
2281-
builder->CreateCondBr(rehash_condition, thenBB, elseBB);
2282-
builder->SetInsertPoint(thenBB);
2283-
{
2223+
llvm_utils->create_if_else(rehash_condition, [&]() {
22842224
rehash(dict, module, key_asr_type, value_asr_type, name2memidx);
2285-
}
2286-
builder->CreateBr(mergeBB);
2287-
2288-
llvm_utils->start_new_block(elseBB);
2289-
llvm_utils->start_new_block(mergeBB);
2225+
}, [=]() {
2226+
});
22902227
}
22912228

22922229
void LLVMDictSeparateChaining::rehash_all_at_once_if_needed(
22932230
llvm::Value* dict, llvm::Module* module,
22942231
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type,
22952232
std::map<std::string, std::map<std::string, int>>& name2memidx) {
2296-
llvm::Function *fn = builder->GetInsertBlock()->getParent();
2297-
llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn);
2298-
llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else");
2299-
llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont");
2300-
23012233
llvm::Value* occupancy = LLVM::CreateLoad(*builder, get_pointer_to_occupancy(dict));
23022234
llvm::Value* buckets_filled = LLVM::CreateLoad(*builder, get_pointer_to_number_of_filled_buckets(dict));
23032235
llvm::Value* rehash_condition = LLVM::CreateLoad(*builder, get_pointer_to_rehash_flag(dict));
@@ -2310,15 +2242,10 @@ namespace LCompilers {
23102242
llvm::APFloat((float) 2.0));
23112243
rehash_condition = builder->CreateAnd(rehash_condition,
23122244
builder->CreateFCmpOGE(avg_ll_length, avg_ll_length_threshold));
2313-
builder->CreateCondBr(rehash_condition, thenBB, elseBB);
2314-
builder->SetInsertPoint(thenBB);
2315-
{
2245+
llvm_utils->create_if_else(rehash_condition, [&]() {
23162246
rehash(dict, module, key_asr_type, value_asr_type, name2memidx);
2317-
}
2318-
builder->CreateBr(mergeBB);
2319-
2320-
llvm_utils->start_new_block(elseBB);
2321-
llvm_utils->start_new_block(mergeBB);
2247+
}, [=]() {
2248+
});
23222249
}
23232250

23242251
void LLVMDict::write_item(llvm::Value* dict, llvm::Value* key,
@@ -2756,17 +2683,9 @@ namespace LCompilers {
27562683
// end
27572684
llvm_utils->start_new_block(loopend);
27582685

2759-
2760-
llvm::Function *fn = builder->GetInsertBlock()->getParent();
2761-
llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn);
2762-
llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else");
2763-
llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont");
2764-
27652686
llvm::Value* cond = builder->CreateICmpEQ(
27662687
LLVM::CreateLoad(*builder, i), current_end_point);
2767-
builder->CreateCondBr(cond, thenBB, elseBB);
2768-
builder->SetInsertPoint(thenBB);
2769-
{
2688+
llvm_utils->create_if_else(cond, [&]() {
27702689
std::string message = "The list does not contain the element: ";
27712690
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("ValueError: %s%d\n");
27722691
llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message);
@@ -2775,12 +2694,8 @@ namespace LCompilers {
27752694
llvm::Value *exit_code = llvm::ConstantInt::get(context,
27762695
llvm::APInt(32, exit_code_int));
27772696
exit(context, module, *builder, exit_code);
2778-
}
2779-
builder->CreateBr(mergeBB);
2780-
2781-
llvm_utils->start_new_block(elseBB);
2782-
llvm_utils->start_new_block(mergeBB);
2783-
2697+
}, [=]() {
2698+
});
27842699
return LLVM::CreateLoad(*builder, i);
27852700
}
27862701

@@ -2831,27 +2746,16 @@ namespace LCompilers {
28312746
llvm_utils->start_new_block(loopbody);
28322747
{
28332748
// if occurrence found, increment cnt
2834-
llvm::Function *fn = builder->GetInsertBlock()->getParent();
2835-
llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn);
2836-
llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else");
2837-
llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont");
2838-
28392749
llvm::Value* left_arg = read_item(list, LLVM::CreateLoad(*builder, i),
28402750
false, module, LLVM::is_llvm_struct(item_type));
28412751
llvm::Value* cond = llvm_utils->is_equal_by_value(left_arg, item, module, item_type);
2842-
builder->CreateCondBr(cond, thenBB, elseBB);
2843-
builder->SetInsertPoint(thenBB);
2844-
{
2752+
llvm_utils->create_if_else(cond, [&]() {
28452753
tmp = builder->CreateAdd(
28462754
LLVM::CreateLoad(*builder, cnt),
28472755
llvm::ConstantInt::get(context, llvm::APInt(32, 1)));
28482756
LLVM::CreateStore(*builder, tmp, cnt);
2849-
}
2850-
builder->CreateBr(mergeBB);
2851-
2852-
llvm_utils->start_new_block(elseBB);
2853-
llvm_utils->start_new_block(mergeBB);
2854-
2757+
}, [=]() {
2758+
});
28552759
// increment i
28562760
tmp = builder->CreateAdd(
28572761
LLVM::CreateLoad(*builder, i),

0 commit comments

Comments
 (0)