From a9ab3509e838b8bd1095a5842f3eb8a95efc5603 Mon Sep 17 00:00:00 2001 From: Gagandeep Singh Date: Fri, 19 Aug 2022 12:41:22 +0530 Subject: [PATCH 1/7] Verify length of final dict after all insertions --- integration_tests/test_dict_01.py | 1 + 1 file changed, 1 insertion(+) diff --git a/integration_tests/test_dict_01.py b/integration_tests/test_dict_01.py index 27301eee4e..04613fd81d 100644 --- a/integration_tests/test_dict_01.py +++ b/integration_tests/test_dict_01.py @@ -12,5 +12,6 @@ def test_dict(): assert abs(rollnumber2cpi[i] - i/100.0 - 5.0) <= 1e-12 assert abs(rollnumber2cpi[0] - 1.1) <= 1e-12 + assert len(rollnumber2cpi) == 1001 test_dict() From 429b08f543a4a6d2ae4b5d94989c2318045c948b Mon Sep 17 00:00:00 2001 From: Gagandeep Singh Date: Fri, 19 Aug 2022 12:41:48 +0530 Subject: [PATCH 2/7] Following changes have been made, 1. Create pos_ptr iterator only if linear probing is done writing to a dict --- src/libasr/codegen/llvm_utils.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libasr/codegen/llvm_utils.h b/src/libasr/codegen/llvm_utils.h index b9c04f5fbd..867167f770 100644 --- a/src/libasr/codegen/llvm_utils.h +++ b/src/libasr/codegen/llvm_utils.h @@ -259,7 +259,7 @@ namespace LFortran { void linear_probing(llvm::Value* capacity, llvm::Value* key_hash, llvm::Value* key, llvm::Value* key_list, llvm::Value* key_mask, llvm::Module& module, - ASR::ttype_t* key_asr_type); + ASR::ttype_t* key_asr_type, bool for_read=false); void linear_probing_for_write(llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Value* value, From a57a02b8e06576730ca2978943ac92a972e563a9 Mon Sep 17 00:00:00 2001 From: Gagandeep Singh Date: Fri, 19 Aug 2022 12:42:41 +0530 Subject: [PATCH 3/7] Following changes have been made, Use key_mask to figure out whether probing is needed for reading a value from the dict. --- src/libasr/codegen/llvm_utils.cpp | 100 +++++++++++++++++++++--------- 1 file changed, 69 insertions(+), 31 deletions(-) diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index 480c6297ec..1de8d6421d 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -314,7 +314,7 @@ namespace LFortran { value_type_code, value_type_size); std::vector dict_type_vec = {llvm::Type::getInt32Ty(context), key_list_type, value_list_type, - llvm::Type::getInt1PtrTy(context)}; + llvm::Type::getInt8PtrTy(context)}; llvm::Type* dict_desc = llvm::StructType::create(context, dict_type_vec, "dict"); typecode2dicttype[llvm_key] = std::make_tuple(dict_desc, std::make_pair(key_type_size, value_type_size), @@ -403,14 +403,13 @@ namespace LFortran { llvm_utils->list_api->list_init(value_type_code, value_list, *module, initial_capacity, initial_capacity); llvm::DataLayout data_layout(module); - size_t bool_size = data_layout.getTypeAllocSize(llvm::Type::getInt1Ty(context)); + size_t mask_size = data_layout.getTypeAllocSize(llvm::Type::getInt8Ty(context)); llvm::Value* llvm_capacity = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, initial_capacity)); - llvm::Value* llvm_bool_size = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), - llvm::APInt(32, bool_size)); + llvm::Value* llvm_mask_size = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, mask_size)); llvm::Value* key_mask = LLVM::lfortran_calloc(context, *module, *builder, llvm_capacity, - llvm_bool_size); - key_mask = builder->CreateBitCast(key_mask, llvm::Type::getInt1PtrTy(context)); + llvm_mask_size); LLVM::CreateStore(*builder, key_mask, get_pointer_to_keymask(dict)); } @@ -514,15 +513,14 @@ namespace LFortran { llvm::Value* src_key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(src)); llvm::Value* dest_key_mask_ptr = get_pointer_to_keymask(dest); llvm::DataLayout data_layout(module); - size_t bool_size = data_layout.getTypeAllocSize(llvm::Type::getInt1Ty(context)); - llvm::Value* llvm_bool_size = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), - llvm::APInt(32, bool_size)); + size_t mask_size = data_layout.getTypeAllocSize(llvm::Type::getInt8Ty(context)); + llvm::Value* llvm_mask_size = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, mask_size)); llvm::Value* src_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(src)); llvm::Value* dest_key_mask = LLVM::lfortran_calloc(context, *module, *builder, src_capacity, - llvm_bool_size); - dest_key_mask = builder->CreateBitCast(dest_key_mask, llvm::Type::getInt1PtrTy(context)); + llvm_mask_size); builder->CreateMemCpy(dest_key_mask, llvm::MaybeAlign(), src_key_mask, - llvm::MaybeAlign(), builder->CreateMul(src_capacity, llvm_bool_size)); + llvm::MaybeAlign(), builder->CreateMul(src_capacity, llvm_mask_size)); LLVM::CreateStore(*builder, dest_key_mask, dest_key_mask_ptr); } @@ -582,11 +580,14 @@ namespace LFortran { void LLVMDict::linear_probing(llvm::Value* capacity, llvm::Value* key_hash, llvm::Value* key, llvm::Value* key_list, llvm::Value* key_mask, llvm::Module& module, - ASR::ttype_t* key_asr_type) { + ASR::ttype_t* key_asr_type, bool for_read) { if( !are_iterators_set ) { - pos_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + if( !for_read ) { + pos_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + } is_key_matching_var = builder->CreateAlloca(llvm::Type::getInt1Ty(context), nullptr); } + LLVM::CreateStore(*builder, key_hash, pos_ptr); llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); @@ -599,8 +600,10 @@ namespace LFortran { llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); llvm::Value* is_key_set = LLVM::CreateLoad(*builder, llvm_utils->create_ptr_gep(key_mask, pos)); + is_key_set = builder->CreateICmpNE(is_key_set, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 0))); llvm::Value* is_key_matching = llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), - llvm::APInt(1, 0)); + llvm::APInt(1, 0)); LLVM::CreateStore(*builder, is_key_matching, is_key_matching_var); llvm::Function *fn = builder->GetInsertBlock()->getParent(); llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); @@ -624,7 +627,8 @@ namespace LFortran { // load factor touches a threshold (which will always be less than 1) // so there will be some key which will not be set. However for safety // we can add an exit from the loop with a error message. - llvm::Value *cond = builder->CreateAnd(is_key_set, builder->CreateNot(LLVM::CreateLoad(*builder, is_key_matching_var))); + llvm::Value *cond = builder->CreateAnd(is_key_set, builder->CreateNot( + LLVM::CreateLoad(*builder, is_key_matching_var))); builder->CreateCondBr(cond, loopbody, loopend); } @@ -633,7 +637,7 @@ namespace LFortran { { llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); pos = builder->CreateAdd(pos, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), - llvm::APInt(32, 1))); + llvm::APInt(32, 1))); pos = builder->CreateSRem(pos, capacity); LLVM::CreateStore(*builder, pos, pos_ptr); } @@ -658,16 +662,23 @@ namespace LFortran { key_asr_type, module, false); llvm_utils->list_api->write_item(value_list, pos, value, value_asr_type, module, false); - llvm::Value* is_slot_empty = builder->CreateNot(LLVM::CreateLoad(*builder, - llvm_utils->create_ptr_gep(key_mask, pos))); + + llvm::Value* key_mask_value = LLVM::CreateLoad(*builder, + llvm_utils->create_ptr_gep(key_mask, pos)); + llvm::Value* is_slot_empty = builder->CreateICmpEQ(key_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 0))); llvm::Value* occupancy_ptr = get_pointer_to_occupancy(dict); is_slot_empty = builder->CreateZExt(is_slot_empty, llvm::Type::getInt32Ty(context)); llvm::Value* occupancy = LLVM::CreateLoad(*builder, occupancy_ptr); LLVM::CreateStore(*builder, builder->CreateAdd(occupancy, is_slot_empty), occupancy_ptr); - LLVM::CreateStore(*builder, - llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), llvm::APInt(1, 1)), - llvm_utils->create_ptr_gep(key_mask, pos)); + + llvm::Value* linear_prob_happened = builder->CreateICmpNE(key_hash, pos); + llvm::Value* set_max_2 = builder->CreateSelect(linear_prob_happened, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 2)), + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))); + LLVM::CreateStore(*builder, set_max_2, llvm_utils->create_ptr_gep(key_mask, key_hash)); + LLVM::CreateStore(*builder, set_max_2, llvm_utils->create_ptr_gep(key_mask, pos)); } llvm::Value* LLVMDict::linear_probing_for_read(llvm::Value* dict, llvm::Value* key_hash, @@ -677,7 +688,29 @@ namespace LFortran { llvm::Value* value_list = get_value_list(dict); llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); - linear_probing(capacity, key_hash, key, key_list, key_mask, module, key_asr_type); + if( !are_iterators_set ) { + pos_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + } + llvm::Function *fn = builder->GetInsertBlock()->getParent(); + llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); + llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); + llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); + llvm::Value* key_mask_value = LLVM::CreateLoad(*builder, + llvm_utils->create_ptr_gep(key_mask, key_hash)); + llvm::Value* is_prob_not_neeeded = builder->CreateICmpEQ(key_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))); + builder->CreateCondBr(is_prob_not_neeeded, thenBB, elseBB); + builder->SetInsertPoint(thenBB); + { + LLVM::CreateStore(*builder, key_hash, pos_ptr); + } + builder->CreateBr(mergeBB); + llvm_utils->start_new_block(elseBB); + { + linear_probing(capacity, key_hash, key, key_list, key_mask, + module, key_asr_type, true); + } + llvm_utils->start_new_block(mergeBB); llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); llvm::Value* item = llvm_utils->list_api->read_item(value_list, pos, true, false); return item; @@ -733,12 +766,11 @@ namespace LFortran { llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); llvm::DataLayout data_layout(module); - size_t bool_size = data_layout.getTypeAllocSize(llvm::Type::getInt1Ty(context)); - llvm::Value* llvm_bool_size = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), - llvm::APInt(32, bool_size)); + size_t mask_size = data_layout.getTypeAllocSize(llvm::Type::getInt8Ty(context)); + llvm::Value* llvm_mask_size = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, mask_size)); llvm::Value* new_key_mask = LLVM::lfortran_calloc(context, *module, *builder, capacity, - llvm_bool_size); - new_key_mask = builder->CreateBitCast(new_key_mask, llvm::Type::getInt1PtrTy(context)); + llvm_mask_size); llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); if( !are_iterators_set ) { @@ -767,6 +799,8 @@ namespace LFortran { llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); llvm::Value* is_key_set = LLVM::CreateLoad(*builder, llvm_utils->create_ptr_gep(key_mask, idx)); + is_key_set = builder->CreateICmpNE(is_key_set, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 0))); builder->CreateCondBr(is_key_set, thenBB, elseBB); builder->SetInsertPoint(thenBB); { @@ -784,9 +818,13 @@ namespace LFortran { llvm::Value* value_dest = llvm_utils->list_api->read_item(new_value_list, pos, true, false); llvm_utils->deepcopy(value, value_dest, value_asr_type, *module); - llvm::Value* key_mask_dest = llvm_utils->create_ptr_gep(new_key_mask, pos); - LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), - llvm::APInt(1, 1)), key_mask_dest); + + llvm::Value* linear_prob_happened = builder->CreateICmpNE(key_hash, pos); + llvm::Value* set_max_2 = builder->CreateSelect(linear_prob_happened, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 2)), + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))); + LLVM::CreateStore(*builder, set_max_2, llvm_utils->create_ptr_gep(new_key_mask, key_hash)); + LLVM::CreateStore(*builder, set_max_2, llvm_utils->create_ptr_gep(new_key_mask, pos)); } builder->CreateBr(mergeBB); From 323ecc74e9f857794b2eaf91c3c0c2bac6fbd264 Mon Sep 17 00:00:00 2001 From: Gagandeep Singh Date: Fri, 19 Aug 2022 15:22:35 +0530 Subject: [PATCH 4/7] Added LLVMDictOptimizedLinearProbing --- src/libasr/codegen/llvm_utils.h | 42 +++++++++++++++++++++++++++++---- 1 file changed, 38 insertions(+), 4 deletions(-) diff --git a/src/libasr/codegen/llvm_utils.h b/src/libasr/codegen/llvm_utils.h index 867167f770..2b77113316 100644 --- a/src/libasr/codegen/llvm_utils.h +++ b/src/libasr/codegen/llvm_utils.h @@ -217,7 +217,7 @@ namespace LFortran { }; class LLVMDict { - private: + protected: llvm::LLVMContext& context; LLVMUtils* llvm_utils; @@ -256,17 +256,20 @@ namespace LFortran { llvm::Value* get_key_hash(llvm::Value* capacity, llvm::Value* key, ASR::ttype_t* key_asr_type, llvm::Module& module); - void linear_probing(llvm::Value* capacity, llvm::Value* key_hash, + virtual + void resolve_collision(llvm::Value* capacity, llvm::Value* key_hash, llvm::Value* key, llvm::Value* key_list, llvm::Value* key_mask, llvm::Module& module, ASR::ttype_t* key_asr_type, bool for_read=false); - void linear_probing_for_write(llvm::Value* dict, llvm::Value* key_hash, + virtual + void resolve_collision_for_write(llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Value* value, llvm::Module& module, ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type); - llvm::Value* linear_probing_for_read(llvm::Value* dict, llvm::Value* key_hash, + virtual + llvm::Value* resolve_collision_for_read(llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, ASR::ttype_t* key_asr_type); @@ -297,6 +300,37 @@ namespace LFortran { ASR::Dict_t* dict_type, llvm::Module* module); llvm::Value* len(llvm::Value* dict); + + virtual ~LLVMDict(); + }; + + class LLVMDictOptimizedLinearProbing: public LLVMDict { + + public: + + LLVMDictOptimizedLinearProbing(llvm::LLVMContext& context_, + LLVMUtils* llvm_utils, + llvm::IRBuilder<>* builder); + + virtual + void resolve_collision(llvm::Value* capacity, llvm::Value* key_hash, + llvm::Value* key, llvm::Value* key_list, + llvm::Value* key_mask, llvm::Module& module, + ASR::ttype_t* key_asr_type, bool for_read=false); + + virtual + void resolve_collision_for_write(llvm::Value* dict, llvm::Value* key_hash, + llvm::Value* key, llvm::Value* value, + llvm::Module& module, ASR::ttype_t* key_asr_type, + ASR::ttype_t* value_asr_type); + + virtual + llvm::Value* resolve_collision_for_read(llvm::Value* dict, llvm::Value* key_hash, + llvm::Value* key, llvm::Module& module, + ASR::ttype_t* key_asr_type); + + virtual ~LLVMDictOptimizedLinearProbing(); + }; } // LFortran From 3f371f228bbcc372554291c885b2d9f38824ed15 Mon Sep 17 00:00:00 2001 From: Gagandeep Singh Date: Fri, 19 Aug 2022 15:22:54 +0530 Subject: [PATCH 5/7] Following changes have been made, 1. Rename 'linear_probing_*' to 'resolve_collision_*' 2. Implemented 'resolve_collision_*' methods for LLVMDictOptimizedLinearProbing --- src/libasr/codegen/llvm_utils.cpp | 168 ++++++++++++++++++++++++++---- 1 file changed, 150 insertions(+), 18 deletions(-) diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index 1de8d6421d..f08f55bade 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -286,6 +286,13 @@ namespace LFortran { is_dict_present(false) { } + LLVMDictOptimizedLinearProbing::LLVMDictOptimizedLinearProbing( + llvm::LLVMContext& context_, + LLVMUtils* llvm_utils_, + llvm::IRBuilder<>* builder_): + LLVMDict(context_, llvm_utils_, builder_) { + } + llvm::Type* LLVMList::get_list_type(llvm::Type* el_type, std::string& type_code, int32_t type_size) { if( typecode2listtype.find(type_code) != typecode2listtype.end() ) { @@ -577,10 +584,86 @@ namespace LFortran { are_iterators_set = false; } - void LLVMDict::linear_probing(llvm::Value* capacity, llvm::Value* key_hash, - llvm::Value* key, llvm::Value* key_list, - llvm::Value* key_mask, llvm::Module& module, - ASR::ttype_t* key_asr_type, bool for_read) { + void LLVMDict::resolve_collision( + llvm::Value* capacity, llvm::Value* key_hash, + llvm::Value* key, llvm::Value* key_list, + llvm::Value* key_mask, llvm::Module& module, + ASR::ttype_t* key_asr_type, bool /*for_read*/) { + if( !are_iterators_set ) { + pos_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + is_key_matching_var = builder->CreateAlloca(llvm::Type::getInt1Ty(context), nullptr); + } + LLVM::CreateStore(*builder, key_hash, pos_ptr); + + + llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); + llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); + llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); + + + // head + llvm_utils->start_new_block(loophead); + { + llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); + llvm::Value* is_key_set = LLVM::CreateLoad(*builder, + llvm_utils->create_ptr_gep(key_mask, pos)); + is_key_set = builder->CreateICmpNE(is_key_set, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 0))); + llvm::Value* is_key_matching = llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), + llvm::APInt(1, 0)); + LLVM::CreateStore(*builder, is_key_matching, is_key_matching_var); + llvm::Function *fn = builder->GetInsertBlock()->getParent(); + llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); + llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); + llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); + builder->CreateCondBr(is_key_set, thenBB, elseBB); + builder->SetInsertPoint(thenBB); + { + llvm::Value* original_key = llvm_utils->list_api->read_item(key_list, pos, + LLVM::is_llvm_struct(key_asr_type), false); + is_key_matching = llvm_utils->is_equal_by_value(key, original_key, module, + key_asr_type); + LLVM::CreateStore(*builder, is_key_matching, is_key_matching_var); + } + builder->CreateBr(mergeBB); + + + llvm_utils->start_new_block(elseBB); + llvm_utils->start_new_block(mergeBB); + // TODO: Allow safe exit if pos becomes key_hash again. + // Ideally should not happen as dict will be resized once + // load factor touches a threshold (which will always be less than 1) + // so there will be some key which will not be set. However for safety + // we can add an exit from the loop with a error message. + llvm::Value *cond = builder->CreateAnd(is_key_set, builder->CreateNot( + LLVM::CreateLoad(*builder, is_key_matching_var))); + builder->CreateCondBr(cond, loopbody, loopend); + } + + + // body + llvm_utils->start_new_block(loopbody); + { + llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); + pos = builder->CreateAdd(pos, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 1))); + pos = builder->CreateSRem(pos, capacity); + LLVM::CreateStore(*builder, pos, pos_ptr); + } + + + builder->CreateBr(loophead); + + + // end + llvm_utils->start_new_block(loopend); + } + + void LLVMDictOptimizedLinearProbing::resolve_collision( + llvm::Value* capacity, llvm::Value* key_hash, + llvm::Value* key, llvm::Value* key_list, + llvm::Value* key_mask, llvm::Module& module, + ASR::ttype_t* key_asr_type, bool for_read) { if( !are_iterators_set ) { if( !for_read ) { pos_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); @@ -648,15 +731,45 @@ namespace LFortran { llvm_utils->start_new_block(loopend); } - void LLVMDict::linear_probing_for_write(llvm::Value* dict, llvm::Value* key_hash, - llvm::Value* key, llvm::Value* value, - llvm::Module& module, ASR::ttype_t* key_asr_type, - ASR::ttype_t* value_asr_type) { + void LLVMDict::resolve_collision_for_write( + llvm::Value* dict, llvm::Value* key_hash, + llvm::Value* key, llvm::Value* value, + llvm::Module& module, ASR::ttype_t* key_asr_type, + ASR::ttype_t* value_asr_type) { llvm::Value* key_list = get_key_list(dict); llvm::Value* value_list = get_value_list(dict); llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); - linear_probing(capacity, key_hash, key, key_list, key_mask, module, key_asr_type); + this->resolve_collision(capacity, key_hash, key, key_list, key_mask, module, key_asr_type); + llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); + llvm_utils->list_api->write_item(key_list, pos, key, + key_asr_type, module, false); + llvm_utils->list_api->write_item(value_list, pos, value, + value_asr_type, module, false); + llvm::Value* key_mask_value = LLVM::CreateLoad(*builder, + llvm_utils->create_ptr_gep(key_mask, pos)); + llvm::Value* is_slot_empty = builder->CreateICmpEQ(key_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 0))); + llvm::Value* occupancy_ptr = get_pointer_to_occupancy(dict); + is_slot_empty = builder->CreateZExt(is_slot_empty, llvm::Type::getInt32Ty(context)); + llvm::Value* occupancy = LLVM::CreateLoad(*builder, occupancy_ptr); + LLVM::CreateStore(*builder, builder->CreateAdd(occupancy, is_slot_empty), + occupancy_ptr); + LLVM::CreateStore(*builder, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1)), + llvm_utils->create_ptr_gep(key_mask, pos)); + } + + void LLVMDictOptimizedLinearProbing::resolve_collision_for_write( + llvm::Value* dict, llvm::Value* key_hash, + llvm::Value* key, llvm::Value* value, + llvm::Module& module, ASR::ttype_t* key_asr_type, + ASR::ttype_t* value_asr_type) { + llvm::Value* key_list = get_key_list(dict); + llvm::Value* value_list = get_value_list(dict); + llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); + llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); + this->resolve_collision(capacity, key_hash, key, key_list, key_mask, module, key_asr_type); llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); llvm_utils->list_api->write_item(key_list, pos, key, key_asr_type, module, false); @@ -681,9 +794,24 @@ namespace LFortran { LLVM::CreateStore(*builder, set_max_2, llvm_utils->create_ptr_gep(key_mask, pos)); } - llvm::Value* LLVMDict::linear_probing_for_read(llvm::Value* dict, llvm::Value* key_hash, - llvm::Value* key, llvm::Module& module, - ASR::ttype_t* key_asr_type) { + llvm::Value* LLVMDict::resolve_collision_for_read( + llvm::Value* dict, llvm::Value* key_hash, + llvm::Value* key, llvm::Module& module, + ASR::ttype_t* key_asr_type) { + llvm::Value* key_list = get_key_list(dict); + llvm::Value* value_list = get_value_list(dict); + llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); + llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); + this->resolve_collision(capacity, key_hash, key, key_list, key_mask, module, key_asr_type); + llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); + llvm::Value* item = llvm_utils->list_api->read_item(value_list, pos, true, false); + return item; + } + + llvm::Value* LLVMDictOptimizedLinearProbing::resolve_collision_for_read( + llvm::Value* dict, llvm::Value* key_hash, + llvm::Value* key, llvm::Module& module, + ASR::ttype_t* key_asr_type) { llvm::Value* key_list = get_key_list(dict); llvm::Value* value_list = get_value_list(dict); llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); @@ -707,7 +835,7 @@ namespace LFortran { builder->CreateBr(mergeBB); llvm_utils->start_new_block(elseBB); { - linear_probing(capacity, key_hash, key, key_list, key_mask, + this->resolve_collision(capacity, key_hash, key, key_list, key_mask, module, key_asr_type, true); } llvm_utils->start_new_block(mergeBB); @@ -809,7 +937,7 @@ namespace LFortran { llvm::Value* value = llvm_utils->list_api->read_item(value_list, idx, LLVM::is_llvm_struct(value_asr_type), false); llvm::Value* key_hash = get_key_hash(current_capacity, key, key_asr_type, *module); - linear_probing(current_capacity, key_hash, key, new_key_list, + this->resolve_collision(current_capacity, key_hash, key, new_key_list, new_key_mask, *module, key_asr_type); llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); llvm::Value* key_dest = llvm_utils->list_api->read_item(new_key_list, pos, @@ -883,8 +1011,8 @@ namespace LFortran { rehash_all_at_once_if_needed(dict, module, key_asr_type, value_asr_type); llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); llvm::Value* key_hash = get_key_hash(current_capacity, key, key_asr_type, *module); - linear_probing_for_write(dict, key_hash, key, value, *module, - key_asr_type, value_asr_type); + this->resolve_collision_for_write(dict, key_hash, key, value, *module, + key_asr_type, value_asr_type); } llvm::Value* LLVMDict::read_item(llvm::Value* dict, llvm::Value* key, @@ -892,8 +1020,8 @@ namespace LFortran { bool get_pointer) { llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); llvm::Value* key_hash = get_key_hash(current_capacity, key, key_asr_type, module); - llvm::Value* value_ptr = linear_probing_for_read(dict, key_hash, key, module, - key_asr_type); + llvm::Value* value_ptr = this->resolve_collision_for_read(dict, key_hash, key, module, + key_asr_type); if( get_pointer ) { return value_ptr; } @@ -921,6 +1049,10 @@ namespace LFortran { return LLVM::CreateLoad(*builder, get_pointer_to_occupancy(dict)); } + LLVMDict::~LLVMDict() {} + + LLVMDictOptimizedLinearProbing::~LLVMDictOptimizedLinearProbing() {} + void LLVMList::resize_if_needed(llvm::Value* list, llvm::Value* n, llvm::Value* capacity, int32_t type_size, llvm::Type* el_type, llvm::Module& module) { From 74bdd1db476c9e29a8f37201a63668726eeda8f8 Mon Sep 17 00:00:00 2001 From: Gagandeep Singh Date: Fri, 19 Aug 2022 15:25:02 +0530 Subject: [PATCH 6/7] Use LLVMDictOptimizedLinearProbing by default --- src/libasr/codegen/asr_to_llvm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index 5e2cef59c8..ef7cd7332a 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -238,7 +238,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor llvm_utils(std::make_unique(context, builder.get())), list_api(std::make_unique(context, llvm_utils.get(), builder.get())), tuple_api(std::make_unique(context, llvm_utils.get(), builder.get())), - dict_api(std::make_unique(context, llvm_utils.get(), builder.get())), + dict_api(std::make_unique(context, llvm_utils.get(), builder.get())), arr_descr(LLVMArrUtils::Descriptor::get_descriptor(context, builder.get(), llvm_utils.get(), From f525230bb87c3074b4ac8f30f60fc5836ced32de Mon Sep 17 00:00:00 2001 From: Gagandeep Singh Date: Fri, 19 Aug 2022 15:40:08 +0530 Subject: [PATCH 7/7] Clear typecode2dicttype in LLVMDict destructor --- src/libasr/codegen/llvm_utils.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index f08f55bade..a1d93ee9b5 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -1049,7 +1049,9 @@ namespace LFortran { return LLVM::CreateLoad(*builder, get_pointer_to_occupancy(dict)); } - LLVMDict::~LLVMDict() {} + LLVMDict::~LLVMDict() { + typecode2dicttype.clear(); + } LLVMDictOptimizedLinearProbing::~LLVMDictOptimizedLinearProbing() {}