Skip to content

Commit 3f371f2

Browse files
committed
Following changes have been made,
1. Rename 'linear_probing_*' to 'resolve_collision_*' 2. Implemented 'resolve_collision_*' methods for LLVMDictOptimizedLinearProbing
1 parent 323ecc7 commit 3f371f2

File tree

1 file changed

+150
-18
lines changed

1 file changed

+150
-18
lines changed

src/libasr/codegen/llvm_utils.cpp

Lines changed: 150 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,13 @@ namespace LFortran {
286286
is_dict_present(false) {
287287
}
288288

289+
LLVMDictOptimizedLinearProbing::LLVMDictOptimizedLinearProbing(
290+
llvm::LLVMContext& context_,
291+
LLVMUtils* llvm_utils_,
292+
llvm::IRBuilder<>* builder_):
293+
LLVMDict(context_, llvm_utils_, builder_) {
294+
}
295+
289296
llvm::Type* LLVMList::get_list_type(llvm::Type* el_type, std::string& type_code,
290297
int32_t type_size) {
291298
if( typecode2listtype.find(type_code) != typecode2listtype.end() ) {
@@ -577,10 +584,86 @@ namespace LFortran {
577584
are_iterators_set = false;
578585
}
579586

580-
void LLVMDict::linear_probing(llvm::Value* capacity, llvm::Value* key_hash,
581-
llvm::Value* key, llvm::Value* key_list,
582-
llvm::Value* key_mask, llvm::Module& module,
583-
ASR::ttype_t* key_asr_type, bool for_read) {
587+
void LLVMDict::resolve_collision(
588+
llvm::Value* capacity, llvm::Value* key_hash,
589+
llvm::Value* key, llvm::Value* key_list,
590+
llvm::Value* key_mask, llvm::Module& module,
591+
ASR::ttype_t* key_asr_type, bool /*for_read*/) {
592+
if( !are_iterators_set ) {
593+
pos_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr);
594+
is_key_matching_var = builder->CreateAlloca(llvm::Type::getInt1Ty(context), nullptr);
595+
}
596+
LLVM::CreateStore(*builder, key_hash, pos_ptr);
597+
598+
599+
llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head");
600+
llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body");
601+
llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end");
602+
603+
604+
// head
605+
llvm_utils->start_new_block(loophead);
606+
{
607+
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
608+
llvm::Value* is_key_set = LLVM::CreateLoad(*builder,
609+
llvm_utils->create_ptr_gep(key_mask, pos));
610+
is_key_set = builder->CreateICmpNE(is_key_set,
611+
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 0)));
612+
llvm::Value* is_key_matching = llvm::ConstantInt::get(llvm::Type::getInt1Ty(context),
613+
llvm::APInt(1, 0));
614+
LLVM::CreateStore(*builder, is_key_matching, is_key_matching_var);
615+
llvm::Function *fn = builder->GetInsertBlock()->getParent();
616+
llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn);
617+
llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else");
618+
llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont");
619+
builder->CreateCondBr(is_key_set, thenBB, elseBB);
620+
builder->SetInsertPoint(thenBB);
621+
{
622+
llvm::Value* original_key = llvm_utils->list_api->read_item(key_list, pos,
623+
LLVM::is_llvm_struct(key_asr_type), false);
624+
is_key_matching = llvm_utils->is_equal_by_value(key, original_key, module,
625+
key_asr_type);
626+
LLVM::CreateStore(*builder, is_key_matching, is_key_matching_var);
627+
}
628+
builder->CreateBr(mergeBB);
629+
630+
631+
llvm_utils->start_new_block(elseBB);
632+
llvm_utils->start_new_block(mergeBB);
633+
// TODO: Allow safe exit if pos becomes key_hash again.
634+
// Ideally should not happen as dict will be resized once
635+
// load factor touches a threshold (which will always be less than 1)
636+
// so there will be some key which will not be set. However for safety
637+
// we can add an exit from the loop with a error message.
638+
llvm::Value *cond = builder->CreateAnd(is_key_set, builder->CreateNot(
639+
LLVM::CreateLoad(*builder, is_key_matching_var)));
640+
builder->CreateCondBr(cond, loopbody, loopend);
641+
}
642+
643+
644+
// body
645+
llvm_utils->start_new_block(loopbody);
646+
{
647+
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
648+
pos = builder->CreateAdd(pos, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context),
649+
llvm::APInt(32, 1)));
650+
pos = builder->CreateSRem(pos, capacity);
651+
LLVM::CreateStore(*builder, pos, pos_ptr);
652+
}
653+
654+
655+
builder->CreateBr(loophead);
656+
657+
658+
// end
659+
llvm_utils->start_new_block(loopend);
660+
}
661+
662+
void LLVMDictOptimizedLinearProbing::resolve_collision(
663+
llvm::Value* capacity, llvm::Value* key_hash,
664+
llvm::Value* key, llvm::Value* key_list,
665+
llvm::Value* key_mask, llvm::Module& module,
666+
ASR::ttype_t* key_asr_type, bool for_read) {
584667
if( !are_iterators_set ) {
585668
if( !for_read ) {
586669
pos_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr);
@@ -648,15 +731,45 @@ namespace LFortran {
648731
llvm_utils->start_new_block(loopend);
649732
}
650733

651-
void LLVMDict::linear_probing_for_write(llvm::Value* dict, llvm::Value* key_hash,
652-
llvm::Value* key, llvm::Value* value,
653-
llvm::Module& module, ASR::ttype_t* key_asr_type,
654-
ASR::ttype_t* value_asr_type) {
734+
void LLVMDict::resolve_collision_for_write(
735+
llvm::Value* dict, llvm::Value* key_hash,
736+
llvm::Value* key, llvm::Value* value,
737+
llvm::Module& module, ASR::ttype_t* key_asr_type,
738+
ASR::ttype_t* value_asr_type) {
655739
llvm::Value* key_list = get_key_list(dict);
656740
llvm::Value* value_list = get_value_list(dict);
657741
llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict));
658742
llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict));
659-
linear_probing(capacity, key_hash, key, key_list, key_mask, module, key_asr_type);
743+
this->resolve_collision(capacity, key_hash, key, key_list, key_mask, module, key_asr_type);
744+
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
745+
llvm_utils->list_api->write_item(key_list, pos, key,
746+
key_asr_type, module, false);
747+
llvm_utils->list_api->write_item(value_list, pos, value,
748+
value_asr_type, module, false);
749+
llvm::Value* key_mask_value = LLVM::CreateLoad(*builder,
750+
llvm_utils->create_ptr_gep(key_mask, pos));
751+
llvm::Value* is_slot_empty = builder->CreateICmpEQ(key_mask_value,
752+
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 0)));
753+
llvm::Value* occupancy_ptr = get_pointer_to_occupancy(dict);
754+
is_slot_empty = builder->CreateZExt(is_slot_empty, llvm::Type::getInt32Ty(context));
755+
llvm::Value* occupancy = LLVM::CreateLoad(*builder, occupancy_ptr);
756+
LLVM::CreateStore(*builder, builder->CreateAdd(occupancy, is_slot_empty),
757+
occupancy_ptr);
758+
LLVM::CreateStore(*builder,
759+
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1)),
760+
llvm_utils->create_ptr_gep(key_mask, pos));
761+
}
762+
763+
void LLVMDictOptimizedLinearProbing::resolve_collision_for_write(
764+
llvm::Value* dict, llvm::Value* key_hash,
765+
llvm::Value* key, llvm::Value* value,
766+
llvm::Module& module, ASR::ttype_t* key_asr_type,
767+
ASR::ttype_t* value_asr_type) {
768+
llvm::Value* key_list = get_key_list(dict);
769+
llvm::Value* value_list = get_value_list(dict);
770+
llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict));
771+
llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict));
772+
this->resolve_collision(capacity, key_hash, key, key_list, key_mask, module, key_asr_type);
660773
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
661774
llvm_utils->list_api->write_item(key_list, pos, key,
662775
key_asr_type, module, false);
@@ -681,9 +794,24 @@ namespace LFortran {
681794
LLVM::CreateStore(*builder, set_max_2, llvm_utils->create_ptr_gep(key_mask, pos));
682795
}
683796

684-
llvm::Value* LLVMDict::linear_probing_for_read(llvm::Value* dict, llvm::Value* key_hash,
685-
llvm::Value* key, llvm::Module& module,
686-
ASR::ttype_t* key_asr_type) {
797+
llvm::Value* LLVMDict::resolve_collision_for_read(
798+
llvm::Value* dict, llvm::Value* key_hash,
799+
llvm::Value* key, llvm::Module& module,
800+
ASR::ttype_t* key_asr_type) {
801+
llvm::Value* key_list = get_key_list(dict);
802+
llvm::Value* value_list = get_value_list(dict);
803+
llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict));
804+
llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict));
805+
this->resolve_collision(capacity, key_hash, key, key_list, key_mask, module, key_asr_type);
806+
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
807+
llvm::Value* item = llvm_utils->list_api->read_item(value_list, pos, true, false);
808+
return item;
809+
}
810+
811+
llvm::Value* LLVMDictOptimizedLinearProbing::resolve_collision_for_read(
812+
llvm::Value* dict, llvm::Value* key_hash,
813+
llvm::Value* key, llvm::Module& module,
814+
ASR::ttype_t* key_asr_type) {
687815
llvm::Value* key_list = get_key_list(dict);
688816
llvm::Value* value_list = get_value_list(dict);
689817
llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict));
@@ -707,7 +835,7 @@ namespace LFortran {
707835
builder->CreateBr(mergeBB);
708836
llvm_utils->start_new_block(elseBB);
709837
{
710-
linear_probing(capacity, key_hash, key, key_list, key_mask,
838+
this->resolve_collision(capacity, key_hash, key, key_list, key_mask,
711839
module, key_asr_type, true);
712840
}
713841
llvm_utils->start_new_block(mergeBB);
@@ -809,7 +937,7 @@ namespace LFortran {
809937
llvm::Value* value = llvm_utils->list_api->read_item(value_list, idx,
810938
LLVM::is_llvm_struct(value_asr_type), false);
811939
llvm::Value* key_hash = get_key_hash(current_capacity, key, key_asr_type, *module);
812-
linear_probing(current_capacity, key_hash, key, new_key_list,
940+
this->resolve_collision(current_capacity, key_hash, key, new_key_list,
813941
new_key_mask, *module, key_asr_type);
814942
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
815943
llvm::Value* key_dest = llvm_utils->list_api->read_item(new_key_list, pos,
@@ -883,17 +1011,17 @@ namespace LFortran {
8831011
rehash_all_at_once_if_needed(dict, module, key_asr_type, value_asr_type);
8841012
llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict));
8851013
llvm::Value* key_hash = get_key_hash(current_capacity, key, key_asr_type, *module);
886-
linear_probing_for_write(dict, key_hash, key, value, *module,
887-
key_asr_type, value_asr_type);
1014+
this->resolve_collision_for_write(dict, key_hash, key, value, *module,
1015+
key_asr_type, value_asr_type);
8881016
}
8891017

8901018
llvm::Value* LLVMDict::read_item(llvm::Value* dict, llvm::Value* key,
8911019
llvm::Module& module, ASR::ttype_t* key_asr_type,
8921020
bool get_pointer) {
8931021
llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict));
8941022
llvm::Value* key_hash = get_key_hash(current_capacity, key, key_asr_type, module);
895-
llvm::Value* value_ptr = linear_probing_for_read(dict, key_hash, key, module,
896-
key_asr_type);
1023+
llvm::Value* value_ptr = this->resolve_collision_for_read(dict, key_hash, key, module,
1024+
key_asr_type);
8971025
if( get_pointer ) {
8981026
return value_ptr;
8991027
}
@@ -921,6 +1049,10 @@ namespace LFortran {
9211049
return LLVM::CreateLoad(*builder, get_pointer_to_occupancy(dict));
9221050
}
9231051

1052+
LLVMDict::~LLVMDict() {}
1053+
1054+
LLVMDictOptimizedLinearProbing::~LLVMDictOptimizedLinearProbing() {}
1055+
9241056
void LLVMList::resize_if_needed(llvm::Value* list, llvm::Value* n,
9251057
llvm::Value* capacity, int32_t type_size,
9261058
llvm::Type* el_type, llvm::Module& module) {

0 commit comments

Comments
 (0)