Skip to content

Commit ce5248a

Browse files
committed
LLVM: Consider default value in get
1 parent 8382484 commit ce5248a

File tree

2 files changed

+28
-15
lines changed

2 files changed

+28
-15
lines changed

src/libasr/codegen/llvm_utils.cpp

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1534,7 +1534,7 @@ namespace LCompilers {
15341534
llvm::Value* LLVMDict::resolve_collision_for_read(
15351535
llvm::Value* dict, llvm::Value* key_hash,
15361536
llvm::Value* key, llvm::Module& module,
1537-
ASR::ttype_t* key_asr_type, ASR::ttype_t* /*value_asr_type*/) {
1537+
ASR::ttype_t* key_asr_type, ASR::ttype_t* /*value_asr_type*/, llvm::Value* /*def_value*/) {
15381538
llvm::Value* key_list = get_key_list(dict);
15391539
llvm::Value* value_list = get_value_list(dict);
15401540
llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict));
@@ -1548,7 +1548,7 @@ namespace LCompilers {
15481548
llvm::Value* LLVMDictOptimizedLinearProbing::resolve_collision_for_read(
15491549
llvm::Value* dict, llvm::Value* key_hash,
15501550
llvm::Value* key, llvm::Module& module,
1551-
ASR::ttype_t* key_asr_type, ASR::ttype_t* /*value_asr_type*/) {
1551+
ASR::ttype_t* key_asr_type, ASR::ttype_t* /*value_asr_type*/, llvm::Value *def_value) {
15521552
llvm::Value* key_list = get_key_list(dict);
15531553
llvm::Value* value_list = get_value_list(dict);
15541554
llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict));
@@ -1601,6 +1601,10 @@ namespace LCompilers {
16011601
builder->CreateBr(mergeBB);
16021602
llvm_utils->start_new_block(elseBB);
16031603
{
1604+
if (def_value != nullptr) {
1605+
llvm_utils->start_new_block(mergeBB);
1606+
return def_value;
1607+
}
16041608
this->resolve_collision(capacity, key_hash, key, key_list, key_mask,
16051609
module, key_asr_type, true);
16061610
}
@@ -1614,7 +1618,7 @@ namespace LCompilers {
16141618
llvm::Value* LLVMDictSeparateChaining::resolve_collision_for_read(
16151619
llvm::Value* dict, llvm::Value* key_hash,
16161620
llvm::Value* key, llvm::Module& module,
1617-
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type) {
1621+
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, llvm::Value *def_value) {
16181622
llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict));
16191623
llvm::Value* key_value_pairs = LLVM::CreateLoad(*builder, get_pointer_to_key_value_pairs(dict));
16201624
llvm::Value* key_value_pair_linked_list = llvm_utils->create_ptr_gep(key_value_pairs, key_hash);
@@ -1658,6 +1662,10 @@ namespace LCompilers {
16581662
builder->CreateBr(mergeBB_single_match);
16591663
llvm_utils->start_new_block(elseBB_single_match);
16601664
{
1665+
if (def_value != nullptr) {
1666+
llvm_utils->start_new_block(mergeBB_single_match);
1667+
return def_value;
1668+
}
16611669
std::string message = "The dict does not contain the specified key";
16621670
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n");
16631671
llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message);
@@ -2103,24 +2111,26 @@ namespace LCompilers {
21032111
}
21042112

21052113
llvm::Value* LLVMDict::read_item(llvm::Value* dict, llvm::Value* key,
2106-
llvm::Module& module, ASR::Dict_t* dict_type,
2114+
llvm::Module& module, ASR::Dict_t* dict_type, llvm::Value* def_value,
21072115
bool get_pointer) {
21082116
llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict));
21092117
llvm::Value* key_hash = get_key_hash(current_capacity, key, dict_type->m_key_type, module);
21102118
llvm::Value* value_ptr = this->resolve_collision_for_read(dict, key_hash, key, module,
2111-
dict_type->m_key_type, dict_type->m_value_type);
2119+
dict_type->m_key_type, dict_type->m_value_type,
2120+
def_value);
21122121
if( get_pointer ) {
21132122
return value_ptr;
21142123
}
21152124
return LLVM::CreateLoad(*builder, value_ptr);
21162125
}
21172126

21182127
llvm::Value* LLVMDictSeparateChaining::read_item(llvm::Value* dict, llvm::Value* key,
2119-
llvm::Module& module, ASR::Dict_t* dict_type, bool get_pointer) {
2128+
llvm::Module& module, ASR::Dict_t* dict_type, llvm::Value* def_value, bool get_pointer) {
21202129
llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict));
21212130
llvm::Value* key_hash = get_key_hash(current_capacity, key, dict_type->m_key_type, module);
21222131
llvm::Value* value_ptr = this->resolve_collision_for_read(dict, key_hash, key, module,
2123-
dict_type->m_key_type, dict_type->m_value_type);
2132+
dict_type->m_key_type, dict_type->m_value_type,
2133+
def_value);
21242134
std::pair<std::string, std::string> llvm_key = std::make_pair(
21252135
ASRUtils::get_type_code(dict_type->m_key_type),
21262136
ASRUtils::get_type_code(dict_type->m_value_type)

src/libasr/codegen/llvm_utils.h

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ namespace LCompilers {
9595
return ASR::is_a<ASR::Tuple_t>(*asr_type) ||
9696
ASR::is_a<ASR::List_t>(*asr_type) ||
9797
ASR::is_a<ASR::Struct_t>(*asr_type) ||
98-
ASR::is_a<ASR::Class_t>(*asr_type)||
98+
ASR::is_a<ASR::Class_t>(*asr_type)||
9999
ASR::is_a<ASR::Dict_t>(*asr_type);
100100
}
101101
}
@@ -345,7 +345,7 @@ namespace LCompilers {
345345
virtual
346346
llvm::Value* resolve_collision_for_read(llvm::Value* dict, llvm::Value* key_hash,
347347
llvm::Value* key, llvm::Module& module,
348-
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type) = 0;
348+
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, llvm::Value *def_value) = 0;
349349

350350
virtual
351351
void rehash(llvm::Value* dict, llvm::Module* module,
@@ -367,7 +367,7 @@ namespace LCompilers {
367367

368368
virtual
369369
llvm::Value* read_item(llvm::Value* dict, llvm::Value* key,
370-
llvm::Module& module, ASR::Dict_t* dict_type,
370+
llvm::Module& module, ASR::Dict_t* dict_type, llvm::Value* def_value=nullptr,
371371
bool get_pointer=false) = 0;
372372

373373
virtual
@@ -436,7 +436,8 @@ namespace LCompilers {
436436

437437
llvm::Value* resolve_collision_for_read(llvm::Value* dict, llvm::Value* key_hash,
438438
llvm::Value* key, llvm::Module& module,
439-
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type);
439+
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type,
440+
llvm::Value *def_value=nullptr);
440441

441442
void rehash(llvm::Value* dict, llvm::Module* module,
442443
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type,
@@ -454,7 +455,7 @@ namespace LCompilers {
454455
std::map<std::string, std::map<std::string, int>>& name2memidx);
455456

456457
llvm::Value* read_item(llvm::Value* dict, llvm::Value* key,
457-
llvm::Module& module, ASR::Dict_t* key_asr_type,
458+
llvm::Module& module, ASR::Dict_t* key_asr_type, llvm::Value* def_value=nullptr,
458459
bool get_pointer=false);
459460

460461
llvm::Value* pop_item(llvm::Value* dict, llvm::Value* key,
@@ -494,7 +495,8 @@ namespace LCompilers {
494495

495496
llvm::Value* resolve_collision_for_read(llvm::Value* dict, llvm::Value* key_hash,
496497
llvm::Value* key, llvm::Module& module,
497-
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type);
498+
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type,
499+
llvm::Value *def_value=nullptr);
498500

499501
virtual ~LLVMDictOptimizedLinearProbing();
500502

@@ -562,7 +564,8 @@ namespace LCompilers {
562564

563565
llvm::Value* resolve_collision_for_read(llvm::Value* dict, llvm::Value* key_hash,
564566
llvm::Value* key, llvm::Module& module,
565-
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type);
567+
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type,
568+
llvm::Value *def_value=nullptr);
566569

567570
void rehash(llvm::Value* dict, llvm::Module* module,
568571
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type,
@@ -580,7 +583,7 @@ namespace LCompilers {
580583
std::map<std::string, std::map<std::string, int>>& name2memidx);
581584

582585
llvm::Value* read_item(llvm::Value* dict, llvm::Value* key,
583-
llvm::Module& module, ASR::Dict_t* dict_type,
586+
llvm::Module& module, ASR::Dict_t* dict_type, llvm::Value* def_value,
584587
bool get_pointer=false);
585588

586589
llvm::Value* pop_item(llvm::Value* dict, llvm::Value* key,

0 commit comments

Comments
 (0)