From 2dabb10a08a8b132dd2d653ef83f0a8df4862a78 Mon Sep 17 00:00:00 2001 From: Smit-create Date: Wed, 12 Apr 2023 12:43:11 +0530 Subject: [PATCH 1/7] C: Support dict item with default value --- src/libasr/codegen/asr_to_c_cpp.h | 13 ++++++++++--- src/libasr/codegen/c_utils.h | 28 ++++++++++++++++++++++++++-- 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/src/libasr/codegen/asr_to_c_cpp.h b/src/libasr/codegen/asr_to_c_cpp.h index 9b220a7384..dd24a4e81b 100644 --- a/src/libasr/codegen/asr_to_c_cpp.h +++ b/src/libasr/codegen/asr_to_c_cpp.h @@ -1065,15 +1065,22 @@ R"(#include void visit_DictItem(const ASR::DictItem_t& x) { ASR::Dict_t* dict_type = ASR::down_cast( ASRUtils::expr_type(x.m_a)); - std::string dict_get_fun = c_ds_api->get_dict_get_func(dict_type); - this->visit_expr(*x.m_a); std::string d_var = std::move(src); this->visit_expr(*x.m_key); std::string k = std::move(src); - src = dict_get_fun + "(&" + d_var + ", " + k + ")"; + if (x.m_default) { + this->visit_expr(*x.m_default); + std::string def_value = std::move(src); + std::string dict_get_fun = c_ds_api->get_dict_get_func(dict_type, + true); + src = dict_get_fun + "(&" + d_var + ", " + k + ", " + def_value + ")"; + } else { + std::string dict_get_fun = c_ds_api->get_dict_get_func(dict_type); + src = dict_get_fun + "(&" + d_var + ", " + k + ")"; + } } void visit_ListAppend(const ASR::ListAppend_t& x) { diff --git a/src/libasr/codegen/c_utils.h b/src/libasr/codegen/c_utils.h index ffb9fb8717..6f2557bf81 100644 --- a/src/libasr/codegen/c_utils.h +++ b/src/libasr/codegen/c_utils.h @@ -1126,8 +1126,11 @@ class CCPPDSUtils { return typecodeToDSfuncs[dict_type_code]["dict_insert"]; } - std::string get_dict_get_func(ASR::Dict_t* d_type) { + std::string get_dict_get_func(ASR::Dict_t* d_type, bool with_fallback=false) { std::string dict_type_code = ASRUtils::get_type_code((ASR::ttype_t*)d_type, true); + if (with_fallback) { + return typecodeToDSfuncs[dict_type_code]["dict_get_fb"]; + } return typecodeToDSfuncs[dict_type_code]["dict_get"]; } @@ -1177,6 +1180,7 @@ class CCPPDSUtils { dict_resize(dict_type, dict_struct_type, dict_type_code); dict_insert(dict_type, dict_struct_type, dict_type_code); dict_get_item(dict_type, dict_struct_type, dict_type_code); + dict_get_item_with_fallback(dict_type, dict_struct_type, dict_type_code); dict_len(dict_type, dict_struct_type, dict_type_code); dict_pop(dict_type, dict_struct_type, dict_type_code); dict_deepcopy(dict_type, dict_struct_type, dict_type_code); @@ -1294,11 +1298,31 @@ class CCPPDSUtils { generated_code += indent + tab + "int j=k\%x->capacity, c = 0;\n"; generated_code += indent + tab + "while(ccapacity && x->present[j] && !(x->key[j] == k)) j=(j+1)\%x->capacity, c++;\n"; generated_code += indent + tab + "if (x->present[j] && x->key[j] == k) return x->value[j];\n"; - generated_code += indent + tab + "printf(\"Key not found\");\n"; + generated_code += indent + tab + "printf(\"Key not found\\n\");\n"; generated_code += indent + tab + "exit(1);\n"; generated_code += indent + "}\n\n"; } + void dict_get_item_with_fallback(ASR::Dict_t *dict_type, std::string dict_struct_type, + std::string dict_type_code) { + std::string indent(indentation_level * indentation_spaces, ' '); + std::string tab(indentation_spaces, ' '); + std::string dict_get_func = global_scope->get_unique_name("dict_get_item_fb_" + dict_type_code); + typecodeToDSfuncs[dict_type_code]["dict_get_fb"] = dict_get_func; + std::string key = CUtils::get_c_type_from_ttype_t(dict_type->m_key_type); + std::string val = CUtils::get_c_type_from_ttype_t(dict_type->m_value_type); + std::string signature = val + " " + dict_get_func + "(" + dict_struct_type + "* x, " +\ + key + " k, " + val + " dv)"; + func_decls += indent + "inline " + signature + ";\n"; + signature = indent + signature; + generated_code += indent + signature + " {\n"; + generated_code += indent + tab + "int j=k\%x->capacity, c = 0;\n"; + generated_code += indent + tab + "while(ccapacity && x->present[j] && !(x->key[j] == k)) j=(j+1)\%x->capacity, c++;\n"; + generated_code += indent + tab + "if (x->present[j] && x->key[j] == k) return x->value[j];\n"; + generated_code += indent + tab + "return dv;\n"; + generated_code += indent + "}\n\n"; + } + void dict_len(ASR::Dict_t *dict_type, std::string dict_struct_type, std::string dict_type_code) { std::string indent(indentation_level * indentation_spaces, ' '); From 9edfe5824917f6690f3d9055e0f739dbac781f64 Mon Sep 17 00:00:00 2001 From: Smit-create Date: Wed, 12 Apr 2023 12:43:33 +0530 Subject: [PATCH 2/7] Add tests --- integration_tests/CMakeLists.txt | 1 + integration_tests/test_dict_11.py | 8 ++++++++ 2 files changed, 9 insertions(+) create mode 100644 integration_tests/test_dict_11.py diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index a2c299e0b3..f2171be481 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -361,6 +361,7 @@ RUN(NAME test_dict_07 LABELS cpython llvm) RUN(NAME test_dict_08 LABELS cpython llvm c) RUN(NAME test_dict_09 LABELS cpython llvm c) RUN(NAME test_dict_10 LABELS cpython llvm) # TODO: Add support of dict with string in C backend +RUN(NAME test_dict_11 LABELS cpython c) # TODO: Add LLVM support RUN(NAME test_for_loop LABELS cpython llvm c) RUN(NAME modules_01 LABELS cpython llvm c wasm wasm_x86 wasm_x64) RUN(NAME modules_02 LABELS cpython llvm c wasm wasm_x86 wasm_x64) diff --git a/integration_tests/test_dict_11.py b/integration_tests/test_dict_11.py new file mode 100644 index 0000000000..beecb6f5f9 --- /dev/null +++ b/integration_tests/test_dict_11.py @@ -0,0 +1,8 @@ +from lpython import i32 + +def test_dict_11(): + num : dict[i32, i32] + num = {11: 22, 33: 44, 55: 66} + assert num.get(7, -1) == -1 + +test_dict_11() From cbb2653c027fa34929ca15aefd7184a77081655c Mon Sep 17 00:00:00 2001 From: Smit-create Date: Thu, 20 Apr 2023 11:00:12 +0530 Subject: [PATCH 3/7] LLVM: Consider default value in get --- src/libasr/codegen/asr_to_llvm.cpp | 17 ++- src/libasr/codegen/llvm_utils.cpp | 192 +++++++++++++++++++++++++++++ src/libasr/codegen/llvm_utils.h | 39 +++++- 3 files changed, 244 insertions(+), 4 deletions(-) diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index 8e394b8859..2f9ce9bba0 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -1793,10 +1793,21 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor this->visit_expr_wrapper(x.m_key, true); ptr_loads = ptr_loads_copy; llvm::Value *key = tmp; - - set_dict_api(dict_type); - tmp = llvm_utils->dict_api->read_item(pdict, key, *module, dict_type, + if (x.m_default) { + llvm::Type *val_type = get_type_from_ttype_t_util(dict_type->m_value_type); + llvm::Value *def_value_ptr = builder->CreateAlloca(val_type, nullptr); + ptr_loads = !LLVM::is_llvm_struct(dict_type->m_value_type); + this->visit_expr_wrapper(x.m_default, true); + ptr_loads = ptr_loads_copy; + builder->CreateStore(tmp, def_value_ptr); + set_dict_api(dict_type); + tmp = llvm_utils->dict_api->get_item(pdict, key, *module, dict_type, def_value_ptr, LLVM::is_llvm_struct(dict_type->m_value_type)); + } else { + set_dict_api(dict_type); + tmp = llvm_utils->dict_api->read_item(pdict, key, *module, dict_type, + LLVM::is_llvm_struct(dict_type->m_value_type)); + } } void visit_DictPop(const ASR::DictPop_t& x) { diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index 2170fe146d..cd2ae71850 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -1545,6 +1545,53 @@ namespace LCompilers { return item; } + void LLVMDict::_check_key_present_or_default(llvm::Module& module, llvm::Value *key, llvm::Value *key_list, + ASR::ttype_t* key_asr_type, llvm::Value *value_list, llvm::Value *pos, + llvm::Value *def_value, llvm::Value* &result) { + llvm::Function *fn_single_match = builder->GetInsertBlock()->getParent(); + llvm::BasicBlock *thenBB_single_match = llvm::BasicBlock::Create(context, "then", fn_single_match); + llvm::BasicBlock *elseBB_single_match = llvm::BasicBlock::Create(context, "else"); + llvm::BasicBlock *mergeBB_single_match = llvm::BasicBlock::Create(context, "ifcont"); + llvm::Value* is_key_matching = llvm_utils->is_equal_by_value(key, + llvm_utils->list_api->read_item(key_list, pos, false, module, + LLVM::is_llvm_struct(key_asr_type)), module, key_asr_type); + builder->CreateCondBr(is_key_matching, thenBB_single_match, elseBB_single_match); + builder->SetInsertPoint(thenBB_single_match); + { + llvm::Value* item = llvm_utils->list_api->read_item(value_list, pos, + false, module, false); + LLVM::CreateStore(*builder, item, result); + } + builder->CreateBr(mergeBB_single_match); + llvm_utils->start_new_block(elseBB_single_match); + { + LLVM::CreateStore(*builder, LLVM::CreateLoad(*builder, def_value), result); + } + llvm_utils->start_new_block(mergeBB_single_match); + } + + llvm::Value* LLVMDict::resolve_collision_for_read_with_default( + llvm::Value* dict, llvm::Value* key_hash, + llvm::Value* key, llvm::Module& module, + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, + llvm::Value* def_value) { + llvm::Value* key_list = get_key_list(dict); + llvm::Value* value_list = get_value_list(dict); + llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); + llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); + this->resolve_collision(capacity, key_hash, key, key_list, key_mask, module, key_asr_type, true); + llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); + std::pair llvm_key = std::make_pair( + ASRUtils::get_type_code(key_asr_type), + ASRUtils::get_type_code(value_asr_type) + ); + llvm::Type* value_type = std::get<2>(typecode2dicttype[llvm_key]).second; + llvm::Value* result = builder->CreateAlloca(value_type, nullptr); + _check_key_present_or_default(module, key, key_list, key_asr_type, value_list, + pos, def_value, result); + return result; + } + llvm::Value* LLVMDictOptimizedLinearProbing::resolve_collision_for_read( llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, @@ -1611,6 +1658,65 @@ namespace LCompilers { return item; } + llvm::Value* LLVMDictOptimizedLinearProbing::resolve_collision_for_read_with_default( + llvm::Value* dict, llvm::Value* key_hash, + llvm::Value* key, llvm::Module& module, + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, + llvm::Value *def_value) { + llvm::Value* key_list = get_key_list(dict); + llvm::Value* value_list = get_value_list(dict); + llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); + llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); + if( !are_iterators_set ) { + pos_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + } + std::pair llvm_key = std::make_pair( + ASRUtils::get_type_code(key_asr_type), + ASRUtils::get_type_code(value_asr_type) + ); + llvm::Type* value_type = std::get<2>(typecode2dicttype[llvm_key]).second; + llvm::Value* result = builder->CreateAlloca(value_type, nullptr); + llvm::Function *fn = builder->GetInsertBlock()->getParent(); + llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); + llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); + llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); + llvm::Value* key_mask_value = LLVM::CreateLoad(*builder, + llvm_utils->create_ptr_gep(key_mask, key_hash)); + llvm::Value* is_prob_not_neeeded = builder->CreateICmpEQ(key_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))); + builder->CreateCondBr(is_prob_not_neeeded, thenBB, elseBB); + builder->SetInsertPoint(thenBB); + { + llvm::Function *fn_single_match = builder->GetInsertBlock()->getParent(); + llvm::BasicBlock *thenBB_single_match = llvm::BasicBlock::Create(context, "then", fn_single_match); + llvm::BasicBlock *elseBB_single_match = llvm::BasicBlock::Create(context, "else"); + llvm::BasicBlock *mergeBB_single_match = llvm::BasicBlock::Create(context, "ifcont"); + llvm::Value* is_key_matching = llvm_utils->is_equal_by_value(key, + llvm_utils->list_api->read_item(key_list, key_hash, false, module, + LLVM::is_llvm_struct(key_asr_type)), module, key_asr_type); + builder->CreateCondBr(is_key_matching, thenBB_single_match, elseBB_single_match); + builder->SetInsertPoint(thenBB_single_match); + LLVM::CreateStore(*builder, key_hash, pos_ptr); + builder->CreateBr(mergeBB_single_match); + llvm_utils->start_new_block(elseBB_single_match); + { + LLVM::CreateStore(*builder, LLVM::CreateLoad(*builder, def_value), result); + } + llvm_utils->start_new_block(mergeBB_single_match); + } + builder->CreateBr(mergeBB); + llvm_utils->start_new_block(elseBB); + { + this->resolve_collision(capacity, key_hash, key, key_list, key_mask, + module, key_asr_type, true); + } + llvm_utils->start_new_block(mergeBB); + llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); + _check_key_present_or_default(module, key, key_list, key_asr_type, value_list, + pos, def_value, result); + return result; + } + llvm::Value* LLVMDictSeparateChaining::resolve_collision_for_read( llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, @@ -1671,6 +1777,59 @@ namespace LCompilers { return tmp_value_ptr; } + llvm::Value* LLVMDictSeparateChaining::resolve_collision_for_read_with_default( + llvm::Value* dict, llvm::Value* key_hash, + llvm::Value* key, llvm::Module& module, + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, llvm::Value *def_value) { + llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); + llvm::Value* key_value_pairs = LLVM::CreateLoad(*builder, get_pointer_to_key_value_pairs(dict)); + llvm::Value* key_value_pair_linked_list = llvm_utils->create_ptr_gep(key_value_pairs, key_hash); + llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); + llvm::Type* kv_struct_type = get_key_value_pair_type(key_asr_type, value_asr_type); + this->resolve_collision(capacity, key_hash, key, key_value_pair_linked_list, + kv_struct_type, key_mask, module, key_asr_type); + std::pair llvm_key = std::make_pair( + ASRUtils::get_type_code(key_asr_type), + ASRUtils::get_type_code(value_asr_type) + ); + llvm::Type* value_type = std::get<2>(typecode2dicttype[llvm_key]).second; + llvm::Value* tmp_value_ptr_local = nullptr; + if( !are_iterators_set ) { + tmp_value_ptr = builder->CreateAlloca(value_type, nullptr); + tmp_value_ptr_local = tmp_value_ptr; + } else { + tmp_value_ptr_local = builder->CreateBitCast(tmp_value_ptr, value_type->getPointerTo()); + } + llvm::Function *fn_single_match = builder->GetInsertBlock()->getParent(); + llvm::BasicBlock *thenBB_single_match = llvm::BasicBlock::Create(context, "then", fn_single_match); + llvm::BasicBlock *elseBB_single_match = llvm::BasicBlock::Create(context, "else"); + llvm::BasicBlock *mergeBB_single_match = llvm::BasicBlock::Create(context, "ifcont"); + llvm::Value* key_mask_value = LLVM::CreateLoad(*builder, + llvm_utils->create_ptr_gep(key_mask, key_hash)); + llvm::Value* does_kv_exists = builder->CreateICmpEQ(key_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))); + does_kv_exists = builder->CreateAnd(does_kv_exists, + builder->CreateICmpNE(LLVM::CreateLoad(*builder, chain_itr), + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))) + ); + builder->CreateCondBr(does_kv_exists, thenBB_single_match, elseBB_single_match); + builder->SetInsertPoint(thenBB_single_match); + { + llvm::Value* kv_struct_i8 = LLVM::CreateLoad(*builder, chain_itr); + llvm::Value* kv_struct = builder->CreateBitCast(kv_struct_i8, kv_struct_type->getPointerTo()); + llvm::Value* value = LLVM::CreateLoad(*builder, llvm_utils->create_gep(kv_struct, 1)); + LLVM::CreateStore(*builder, value, tmp_value_ptr_local); + + } + builder->CreateBr(mergeBB_single_match); + llvm_utils->start_new_block(elseBB_single_match); + { + LLVM::CreateStore(*builder, LLVM::CreateLoad(*builder, def_value), tmp_value_ptr_local); + } + llvm_utils->start_new_block(mergeBB_single_match); + return tmp_value_ptr; + } + llvm::Value* LLVMDictInterface::get_key_hash(llvm::Value* capacity, llvm::Value* key, ASR::ttype_t* key_asr_type, llvm::Module& module) { // Write specialised hash functions for intrinsic types @@ -2115,6 +2274,20 @@ namespace LCompilers { return LLVM::CreateLoad(*builder, value_ptr); } + llvm::Value* LLVMDict::get_item(llvm::Value* dict, llvm::Value* key, + llvm::Module& module, ASR::Dict_t* dict_type, llvm::Value* def_value, + bool get_pointer) { + llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); + llvm::Value* key_hash = get_key_hash(current_capacity, key, dict_type->m_key_type, module); + llvm::Value* value_ptr = this->resolve_collision_for_read_with_default(dict, key_hash, key, module, + dict_type->m_key_type, dict_type->m_value_type, + def_value); + if( get_pointer ) { + return value_ptr; + } + return LLVM::CreateLoad(*builder, value_ptr); + } + llvm::Value* LLVMDictSeparateChaining::read_item(llvm::Value* dict, llvm::Value* key, llvm::Module& module, ASR::Dict_t* dict_type, bool get_pointer) { llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); @@ -2133,6 +2306,25 @@ namespace LCompilers { return LLVM::CreateLoad(*builder, value_ptr); } + llvm::Value* LLVMDictSeparateChaining::get_item(llvm::Value* dict, llvm::Value* key, + llvm::Module& module, ASR::Dict_t* dict_type, llvm::Value* def_value, bool get_pointer) { + llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); + llvm::Value* key_hash = get_key_hash(current_capacity, key, dict_type->m_key_type, module); + llvm::Value* value_ptr = this->resolve_collision_for_read_with_default(dict, key_hash, key, module, + dict_type->m_key_type, dict_type->m_value_type, + def_value); + std::pair llvm_key = std::make_pair( + ASRUtils::get_type_code(dict_type->m_key_type), + ASRUtils::get_type_code(dict_type->m_value_type) + ); + llvm::Type* value_type = std::get<2>(typecode2dicttype[llvm_key]).second; + value_ptr = builder->CreateBitCast(value_ptr, value_type->getPointerTo()); + if( get_pointer ) { + return value_ptr; + } + return LLVM::CreateLoad(*builder, value_ptr); + } + llvm::Value* LLVMDict::pop_item(llvm::Value* dict, llvm::Value* key, llvm::Module& module, ASR::Dict_t* dict_type, bool get_pointer) { diff --git a/src/libasr/codegen/llvm_utils.h b/src/libasr/codegen/llvm_utils.h index 4b33549e79..0cc64cc630 100644 --- a/src/libasr/codegen/llvm_utils.h +++ b/src/libasr/codegen/llvm_utils.h @@ -95,7 +95,7 @@ namespace LCompilers { return ASR::is_a(*asr_type) || ASR::is_a(*asr_type) || ASR::is_a(*asr_type) || - ASR::is_a(*asr_type)|| + ASR::is_a(*asr_type)|| ASR::is_a(*asr_type); } } @@ -347,6 +347,11 @@ namespace LCompilers { llvm::Value* key, llvm::Module& module, ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type) = 0; + virtual + llvm::Value* resolve_collision_for_read_with_default(llvm::Value* dict, llvm::Value* key_hash, + llvm::Value* key, llvm::Module& module, + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, llvm::Value* def_value) = 0; + virtual void rehash(llvm::Value* dict, llvm::Module* module, ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, @@ -370,6 +375,11 @@ namespace LCompilers { llvm::Module& module, ASR::Dict_t* dict_type, bool get_pointer=false) = 0; + virtual + llvm::Value* get_item(llvm::Value* dict, llvm::Value* key, + llvm::Module& module, ASR::Dict_t* dict_type, llvm::Value* def_value, + bool get_pointer=false) = 0; + virtual llvm::Value* pop_item(llvm::Value* dict, llvm::Value* key, llvm::Module& module, ASR::Dict_t* dict_type, @@ -434,10 +444,19 @@ namespace LCompilers { ASR::ttype_t* value_asr_type, std::map>& name2memidx); + void _check_key_present_or_default(llvm::Module& module, llvm::Value *key, llvm::Value *key_list, + ASR::ttype_t* key_asr_type, llvm::Value *value_list, llvm::Value *pos, + llvm::Value *def_value, llvm::Value* &result); + llvm::Value* resolve_collision_for_read(llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type); + llvm::Value* resolve_collision_for_read_with_default(llvm::Value* dict, llvm::Value* key_hash, + llvm::Value* key, llvm::Module& module, + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, + llvm::Value* def_value); + void rehash(llvm::Value* dict, llvm::Module* module, ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, std::map>& name2memidx); @@ -457,6 +476,10 @@ namespace LCompilers { llvm::Module& module, ASR::Dict_t* key_asr_type, bool get_pointer=false); + llvm::Value* get_item(llvm::Value* dict, llvm::Value* key, + llvm::Module& module, ASR::Dict_t* key_asr_type, llvm::Value* def_value, + bool get_pointer=false); + llvm::Value* pop_item(llvm::Value* dict, llvm::Value* key, llvm::Module& module, ASR::Dict_t* dict_type, bool get_pointer=false); @@ -496,6 +519,11 @@ namespace LCompilers { llvm::Value* key, llvm::Module& module, ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type); + llvm::Value* resolve_collision_for_read_with_default(llvm::Value* dict, llvm::Value* key_hash, + llvm::Value* key, llvm::Module& module, + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, + llvm::Value *def_value); + virtual ~LLVMDictOptimizedLinearProbing(); }; @@ -564,6 +592,11 @@ namespace LCompilers { llvm::Value* key, llvm::Module& module, ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type); + llvm::Value* resolve_collision_for_read_with_default(llvm::Value* dict, llvm::Value* key_hash, + llvm::Value* key, llvm::Module& module, + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, + llvm::Value* def_value); + void rehash(llvm::Value* dict, llvm::Module* module, ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, std::map>& name2memidx); @@ -583,6 +616,10 @@ namespace LCompilers { llvm::Module& module, ASR::Dict_t* dict_type, bool get_pointer=false); + llvm::Value* get_item(llvm::Value* dict, llvm::Value* key, + llvm::Module& module, ASR::Dict_t* dict_type, llvm::Value* def_value, + bool get_pointer=false); + llvm::Value* pop_item(llvm::Value* dict, llvm::Value* key, llvm::Module& module, ASR::Dict_t* dict_type, bool get_pointer=false); From 9e5da4c5ade73ff1d69033732c683a8fd593515d Mon Sep 17 00:00:00 2001 From: Smit-create Date: Thu, 20 Apr 2023 11:00:24 +0530 Subject: [PATCH 4/7] LLVM: Add and update tests --- integration_tests/CMakeLists.txt | 2 +- integration_tests/test_dict_11.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index f2171be481..815b9e052a 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -361,7 +361,7 @@ RUN(NAME test_dict_07 LABELS cpython llvm) RUN(NAME test_dict_08 LABELS cpython llvm c) RUN(NAME test_dict_09 LABELS cpython llvm c) RUN(NAME test_dict_10 LABELS cpython llvm) # TODO: Add support of dict with string in C backend -RUN(NAME test_dict_11 LABELS cpython c) # TODO: Add LLVM support +RUN(NAME test_dict_11 LABELS cpython llvm c) RUN(NAME test_for_loop LABELS cpython llvm c) RUN(NAME modules_01 LABELS cpython llvm c wasm wasm_x86 wasm_x64) RUN(NAME modules_02 LABELS cpython llvm c wasm wasm_x86 wasm_x64) diff --git a/integration_tests/test_dict_11.py b/integration_tests/test_dict_11.py index beecb6f5f9..57c64fe1b8 100644 --- a/integration_tests/test_dict_11.py +++ b/integration_tests/test_dict_11.py @@ -4,5 +4,13 @@ def test_dict_11(): num : dict[i32, i32] num = {11: 22, 33: 44, 55: 66} assert num.get(7, -1) == -1 + assert num.get(11, -1) == 22 + assert num.get(33, -1) == 44 + assert num.get(55, -1) == 66 + assert num.get(72, -110) == -110 + d : dict[i32, str] + d = {1: "1", 2: "22", 3: "333"} + assert d.get(2, "00") == "22" + assert d.get(21, "nokey") == "nokey" test_dict_11() From 99609225644a4a43157d421146741496216e5caa Mon Sep 17 00:00:00 2001 From: Smit-create Date: Thu, 20 Apr 2023 11:33:33 +0530 Subject: [PATCH 5/7] Refactor and use create_if_else --- src/libasr/codegen/asr_to_llvm.cpp | 40 +++----------- src/libasr/codegen/llvm_utils.cpp | 85 +++++++----------------------- src/libasr/codegen/llvm_utils.h | 29 ++++++++++ 3 files changed, 55 insertions(+), 99 deletions(-) diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index 2f9ce9bba0..0657303358 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -281,34 +281,6 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor builder->SetInsertPoint(bb); } - // Note: `create_if_else` and `create_loop` are optional APIs - // that do not have to be used. Many times, for more complicated - // things, it might be more readable to just use the LLVM API - // without any extra layer on top. In some other cases, it might - // be more readable to use this abstraction. - // The `if_block` and `else_block` must generate one or more blocks. In - // addition, the `if_block` must not be terminated, we terminate it - // ourselves. The `else_block` can be either terminated or not. - template - void create_if_else(llvm::Value * cond, IF if_block, ELSE else_block) { - llvm::Function *fn = builder->GetInsertBlock()->getParent(); - - llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); - llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); - llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); - - builder->CreateCondBr(cond, thenBB, elseBB); - builder->SetInsertPoint(thenBB); { - if_block(); - } - builder->CreateBr(mergeBB); - - start_new_block(elseBB); { - else_block(); - } - start_new_block(mergeBB); - } - template void create_loop(char *name, Cond condition, Body loop_body) { dict_api_lp->set_iterators(); @@ -1487,7 +1459,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } } llvm::Value *cond = arr_descr->get_is_allocated_flag(tmp); - create_if_else(cond, [=]() { + llvm_utils->create_if_else(cond, [=]() { call_lfortran_free(free_fn); }, [](){}); } @@ -5151,7 +5123,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor void visit_If(const ASR::If_t &x) { this->visit_expr_wrapper(x.m_test, true); - create_if_else(tmp, [=]() { + llvm_utils->create_if_else(tmp, [=]() { for (size_t i=0; ivisit_stmt(*x.m_body[i]); } @@ -5168,7 +5140,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor llvm::Value *cond = tmp; llvm::Value *then_val = nullptr; llvm::Value *else_val = nullptr; - create_if_else(cond, [=, &then_val]() { + llvm_utils->create_if_else(cond, [=, &then_val]() { this->visit_expr_wrapper(x.m_body, true); then_val = tmp; }, [=, &else_val]() { @@ -5324,7 +5296,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } switch (x.m_op) { case ASR::logicalbinopType::And: { - create_if_else(cond, [&, result, left_val]() { + llvm_utils->create_if_else(cond, [&, result, left_val]() { LLVM::CreateStore(*builder, left_val, result); }, [&, result, right_val]() { LLVM::CreateStore(*builder, right_val, result); @@ -5333,7 +5305,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor break; }; case ASR::logicalbinopType::Or: { - create_if_else(cond, [&, result, right_val]() { + llvm_utils->create_if_else(cond, [&, result, right_val]() { LLVM::CreateStore(*builder, right_val, result); }, [&, result, left_val]() { @@ -5865,7 +5837,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor void visit_Assert(const ASR::Assert_t &x) { if (compiler_options.emit_debug_info) debug_emit_loc(x); this->visit_expr_wrapper(x.m_test, true); - create_if_else(tmp, []() {}, [=]() { + llvm_utils->create_if_else(tmp, []() {}, [=]() { if (compiler_options.emit_debug_info) { llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr(infile); llvm::Value *fmt_ptr1 = llvm::ConstantInt::get(context, llvm::APInt( diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index cd2ae71850..1e0b82e864 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -1548,26 +1548,16 @@ namespace LCompilers { void LLVMDict::_check_key_present_or_default(llvm::Module& module, llvm::Value *key, llvm::Value *key_list, ASR::ttype_t* key_asr_type, llvm::Value *value_list, llvm::Value *pos, llvm::Value *def_value, llvm::Value* &result) { - llvm::Function *fn_single_match = builder->GetInsertBlock()->getParent(); - llvm::BasicBlock *thenBB_single_match = llvm::BasicBlock::Create(context, "then", fn_single_match); - llvm::BasicBlock *elseBB_single_match = llvm::BasicBlock::Create(context, "else"); - llvm::BasicBlock *mergeBB_single_match = llvm::BasicBlock::Create(context, "ifcont"); llvm::Value* is_key_matching = llvm_utils->is_equal_by_value(key, llvm_utils->list_api->read_item(key_list, pos, false, module, LLVM::is_llvm_struct(key_asr_type)), module, key_asr_type); - builder->CreateCondBr(is_key_matching, thenBB_single_match, elseBB_single_match); - builder->SetInsertPoint(thenBB_single_match); - { + llvm_utils->create_if_else(is_key_matching, [&]() { llvm::Value* item = llvm_utils->list_api->read_item(value_list, pos, false, module, false); LLVM::CreateStore(*builder, item, result); - } - builder->CreateBr(mergeBB_single_match); - llvm_utils->start_new_block(elseBB_single_match); - { + }, [=]() { LLVM::CreateStore(*builder, LLVM::CreateLoad(*builder, def_value), result); - } - llvm_utils->start_new_block(mergeBB_single_match); + }); } llvm::Value* LLVMDict::resolve_collision_for_read_with_default( @@ -1621,19 +1611,13 @@ namespace LCompilers { // In the above case we will end up returning value for a key // which is not present in the dict. Instead we should return an error // which is done in the below code. - llvm::Function *fn_single_match = builder->GetInsertBlock()->getParent(); - llvm::BasicBlock *thenBB_single_match = llvm::BasicBlock::Create(context, "then", fn_single_match); - llvm::BasicBlock *elseBB_single_match = llvm::BasicBlock::Create(context, "else"); - llvm::BasicBlock *mergeBB_single_match = llvm::BasicBlock::Create(context, "ifcont"); llvm::Value* is_key_matching = llvm_utils->is_equal_by_value(key, llvm_utils->list_api->read_item(key_list, key_hash, false, module, LLVM::is_llvm_struct(key_asr_type)), module, key_asr_type); - builder->CreateCondBr(is_key_matching, thenBB_single_match, elseBB_single_match); - builder->SetInsertPoint(thenBB_single_match); - LLVM::CreateStore(*builder, key_hash, pos_ptr); - builder->CreateBr(mergeBB_single_match); - llvm_utils->start_new_block(elseBB_single_match); - { + + llvm_utils->create_if_else(is_key_matching, [=]() { + LLVM::CreateStore(*builder, key_hash, pos_ptr); + }, [&]() { std::string message = "The dict does not contain the specified key"; llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n"); llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message); @@ -1642,8 +1626,7 @@ namespace LCompilers { llvm::Value *exit_code = llvm::ConstantInt::get(context, llvm::APInt(32, exit_code_int)); exit(context, module, *builder, exit_code); - } - llvm_utils->start_new_block(mergeBB_single_match); + }); } builder->CreateBr(mergeBB); llvm_utils->start_new_block(elseBB); @@ -1687,22 +1670,14 @@ namespace LCompilers { builder->CreateCondBr(is_prob_not_neeeded, thenBB, elseBB); builder->SetInsertPoint(thenBB); { - llvm::Function *fn_single_match = builder->GetInsertBlock()->getParent(); - llvm::BasicBlock *thenBB_single_match = llvm::BasicBlock::Create(context, "then", fn_single_match); - llvm::BasicBlock *elseBB_single_match = llvm::BasicBlock::Create(context, "else"); - llvm::BasicBlock *mergeBB_single_match = llvm::BasicBlock::Create(context, "ifcont"); llvm::Value* is_key_matching = llvm_utils->is_equal_by_value(key, llvm_utils->list_api->read_item(key_list, key_hash, false, module, LLVM::is_llvm_struct(key_asr_type)), module, key_asr_type); - builder->CreateCondBr(is_key_matching, thenBB_single_match, elseBB_single_match); - builder->SetInsertPoint(thenBB_single_match); - LLVM::CreateStore(*builder, key_hash, pos_ptr); - builder->CreateBr(mergeBB_single_match); - llvm_utils->start_new_block(elseBB_single_match); - { + llvm_utils->create_if_else(is_key_matching, [=]() { + LLVM::CreateStore(*builder, key_hash, pos_ptr); + }, [=]() { LLVM::CreateStore(*builder, LLVM::CreateLoad(*builder, def_value), result); - } - llvm_utils->start_new_block(mergeBB_single_match); + }); } builder->CreateBr(mergeBB); llvm_utils->start_new_block(elseBB); @@ -1740,10 +1715,6 @@ namespace LCompilers { } else { tmp_value_ptr_local = builder->CreateBitCast(tmp_value_ptr, value_type->getPointerTo()); } - llvm::Function *fn_single_match = builder->GetInsertBlock()->getParent(); - llvm::BasicBlock *thenBB_single_match = llvm::BasicBlock::Create(context, "then", fn_single_match); - llvm::BasicBlock *elseBB_single_match = llvm::BasicBlock::Create(context, "else"); - llvm::BasicBlock *mergeBB_single_match = llvm::BasicBlock::Create(context, "ifcont"); llvm::Value* key_mask_value = LLVM::CreateLoad(*builder, llvm_utils->create_ptr_gep(key_mask, key_hash)); llvm::Value* does_kv_exists = builder->CreateICmpEQ(key_mask_value, @@ -1752,18 +1723,13 @@ namespace LCompilers { builder->CreateICmpNE(LLVM::CreateLoad(*builder, chain_itr), llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))) ); - builder->CreateCondBr(does_kv_exists, thenBB_single_match, elseBB_single_match); - builder->SetInsertPoint(thenBB_single_match); - { + + llvm_utils->create_if_else(does_kv_exists, [=]() { llvm::Value* kv_struct_i8 = LLVM::CreateLoad(*builder, chain_itr); llvm::Value* kv_struct = builder->CreateBitCast(kv_struct_i8, kv_struct_type->getPointerTo()); llvm::Value* value = LLVM::CreateLoad(*builder, llvm_utils->create_gep(kv_struct, 1)); LLVM::CreateStore(*builder, value, tmp_value_ptr_local); - - } - builder->CreateBr(mergeBB_single_match); - llvm_utils->start_new_block(elseBB_single_match); - { + }, [&]() { std::string message = "The dict does not contain the specified key"; llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n"); llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message); @@ -1772,8 +1738,7 @@ namespace LCompilers { llvm::Value *exit_code = llvm::ConstantInt::get(context, llvm::APInt(32, exit_code_int)); exit(context, module, *builder, exit_code); - } - llvm_utils->start_new_block(mergeBB_single_match); + }); return tmp_value_ptr; } @@ -1800,10 +1765,6 @@ namespace LCompilers { } else { tmp_value_ptr_local = builder->CreateBitCast(tmp_value_ptr, value_type->getPointerTo()); } - llvm::Function *fn_single_match = builder->GetInsertBlock()->getParent(); - llvm::BasicBlock *thenBB_single_match = llvm::BasicBlock::Create(context, "then", fn_single_match); - llvm::BasicBlock *elseBB_single_match = llvm::BasicBlock::Create(context, "else"); - llvm::BasicBlock *mergeBB_single_match = llvm::BasicBlock::Create(context, "ifcont"); llvm::Value* key_mask_value = LLVM::CreateLoad(*builder, llvm_utils->create_ptr_gep(key_mask, key_hash)); llvm::Value* does_kv_exists = builder->CreateICmpEQ(key_mask_value, @@ -1812,21 +1773,15 @@ namespace LCompilers { builder->CreateICmpNE(LLVM::CreateLoad(*builder, chain_itr), llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))) ); - builder->CreateCondBr(does_kv_exists, thenBB_single_match, elseBB_single_match); - builder->SetInsertPoint(thenBB_single_match); - { + + llvm_utils->create_if_else(does_kv_exists, [=]() { llvm::Value* kv_struct_i8 = LLVM::CreateLoad(*builder, chain_itr); llvm::Value* kv_struct = builder->CreateBitCast(kv_struct_i8, kv_struct_type->getPointerTo()); llvm::Value* value = LLVM::CreateLoad(*builder, llvm_utils->create_gep(kv_struct, 1)); LLVM::CreateStore(*builder, value, tmp_value_ptr_local); - - } - builder->CreateBr(mergeBB_single_match); - llvm_utils->start_new_block(elseBB_single_match); - { + }, [&]() { LLVM::CreateStore(*builder, LLVM::CreateLoad(*builder, def_value), tmp_value_ptr_local); - } - llvm_utils->start_new_block(mergeBB_single_match); + }); return tmp_value_ptr; } diff --git a/src/libasr/codegen/llvm_utils.h b/src/libasr/codegen/llvm_utils.h index 0cc64cc630..81da377528 100644 --- a/src/libasr/codegen/llvm_utils.h +++ b/src/libasr/codegen/llvm_utils.h @@ -150,6 +150,35 @@ namespace LCompilers { ASR::ttype_t* asr_type, llvm::Module* module, std::map>& name2memidx); + + // Note: `llvm_utils->create_if_else` and `create_loop` are optional APIs + // that do not have to be used. Many times, for more complicated + // things, it might be more readable to just use the LLVM API + // without any extra layer on top. In some other cases, it might + // be more readable to use this abstraction. + // The `if_block` and `else_block` must generate one or more blocks. In + // addition, the `if_block` must not be terminated, we terminate it + // ourselves. The `else_block` can be either terminated or not. + template + void create_if_else(llvm::Value * cond, IF if_block, ELSE else_block) { + llvm::Function *fn = builder->GetInsertBlock()->getParent(); + + llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); + llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); + llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); + + builder->CreateCondBr(cond, thenBB, elseBB); + builder->SetInsertPoint(thenBB); { + if_block(); + } + builder->CreateBr(mergeBB); + + start_new_block(elseBB); { + else_block(); + } + start_new_block(mergeBB); + } + }; // LLVMUtils class LLVMList { From 979a759823b85b17d6c283013df368fad5d4344e Mon Sep 17 00:00:00 2001 From: Smit-create Date: Thu, 20 Apr 2023 11:50:36 +0530 Subject: [PATCH 6/7] Catch key not found error --- src/libasr/codegen/llvm_utils.cpp | 18 +++++++++++++++++- tests/errors/test_dict14.py | 8 ++++++++ .../reference/runtime-test_dict14-421fe53.json | 13 +++++++++++++ .../runtime-test_dict14-421fe53.stderr | 1 + tests/tests.toml | 4 ++++ 5 files changed, 43 insertions(+), 1 deletion(-) create mode 100644 tests/errors/test_dict14.py create mode 100644 tests/reference/runtime-test_dict14-421fe53.json create mode 100644 tests/reference/runtime-test_dict14-421fe53.stderr diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index 1e0b82e864..e1393c3d34 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -1636,8 +1636,24 @@ namespace LCompilers { } llvm_utils->start_new_block(mergeBB); llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); + // Check if the actual key is present or not + llvm::Value* is_key_matching = llvm_utils->is_equal_by_value(key, + llvm_utils->list_api->read_item(key_list, pos, false, module, + LLVM::is_llvm_struct(key_asr_type)), module, key_asr_type); + + llvm_utils->create_if_else(is_key_matching, [&]() { + }, [&]() { + std::string message = "The dict does not contain the specified key"; + llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n"); + llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message); + print_error(context, module, *builder, {fmt_ptr, fmt_ptr2}); + int exit_code_int = 1; + llvm::Value *exit_code = llvm::ConstantInt::get(context, + llvm::APInt(32, exit_code_int)); + exit(context, module, *builder, exit_code); + }); llvm::Value* item = llvm_utils->list_api->read_item(value_list, pos, - false, module, true); + false, module, true);; return item; } diff --git a/tests/errors/test_dict14.py b/tests/errors/test_dict14.py new file mode 100644 index 0000000000..7a10bb5704 --- /dev/null +++ b/tests/errors/test_dict14.py @@ -0,0 +1,8 @@ +from lpython import i32 + +def key_not_present(): + x: dict[i32, i32] = {} + x = {1: 1, 2: 2} + print(x[10]) + +key_not_present() diff --git a/tests/reference/runtime-test_dict14-421fe53.json b/tests/reference/runtime-test_dict14-421fe53.json new file mode 100644 index 0000000000..ae6a837f77 --- /dev/null +++ b/tests/reference/runtime-test_dict14-421fe53.json @@ -0,0 +1,13 @@ +{ + "basename": "runtime-test_dict14-421fe53", + "cmd": "lpython {infile}", + "infile": "tests/errors/test_dict14.py", + "infile_hash": "c81e4a1e050c87d04f537b49fcfb0b479a5ce3ed7735f90d2c347dba", + "outfile": null, + "outfile_hash": null, + "stdout": null, + "stdout_hash": null, + "stderr": "runtime-test_dict14-421fe53.stderr", + "stderr_hash": "cb46ef04db0862506d688ebe8830a50afaaead9b0d29b0c007dd149a", + "returncode": 1 +} \ No newline at end of file diff --git a/tests/reference/runtime-test_dict14-421fe53.stderr b/tests/reference/runtime-test_dict14-421fe53.stderr new file mode 100644 index 0000000000..e8c90e4e1d --- /dev/null +++ b/tests/reference/runtime-test_dict14-421fe53.stderr @@ -0,0 +1 @@ +KeyError: The dict does not contain the specified key diff --git a/tests/tests.toml b/tests/tests.toml index 549370df82..809a678b97 100644 --- a/tests/tests.toml +++ b/tests/tests.toml @@ -813,6 +813,10 @@ asr = true filename = "errors/test_dict13.py" asr = true +[[test]] +filename = "errors/test_dict14.py" +run = true + [[test]] filename = "errors/test_zero_division.py" asr = true From a01c27595444995e72c5cf39631973d5bd42d31e Mon Sep 17 00:00:00 2001 From: Smit-create Date: Sat, 22 Apr 2023 11:29:08 +0530 Subject: [PATCH 7/7] Add option for bound checking in dict --- src/libasr/codegen/asr_to_llvm.cpp | 1 + src/libasr/codegen/llvm_utils.cpp | 135 +++++++++++++++++++++++++++-- src/libasr/codegen/llvm_utils.h | 23 ++++- 3 files changed, 150 insertions(+), 9 deletions(-) diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index 0657303358..b5c43e5e6a 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -1778,6 +1778,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } else { set_dict_api(dict_type); tmp = llvm_utils->dict_api->read_item(pdict, key, *module, dict_type, + compiler_options.enable_bounds_checking, LLVM::is_llvm_struct(dict_type->m_value_type)); } } diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index e1393c3d34..3072540619 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -1545,6 +1545,36 @@ namespace LCompilers { return item; } + llvm::Value* LLVMDict::resolve_collision_for_read_with_bound_check( + llvm::Value* dict, llvm::Value* key_hash, + llvm::Value* key, llvm::Module& module, + ASR::ttype_t* key_asr_type, ASR::ttype_t* /*value_asr_type*/) { + llvm::Value* key_list = get_key_list(dict); + llvm::Value* value_list = get_value_list(dict); + llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); + llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); + this->resolve_collision(capacity, key_hash, key, key_list, key_mask, module, key_asr_type, true); + llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); + llvm::Value* is_key_matching = llvm_utils->is_equal_by_value(key, + llvm_utils->list_api->read_item(key_list, pos, false, module, + LLVM::is_llvm_struct(key_asr_type)), module, key_asr_type); + + llvm_utils->create_if_else(is_key_matching, [&]() { + }, [&]() { + std::string message = "The dict does not contain the specified key"; + llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n"); + llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message); + print_error(context, module, *builder, {fmt_ptr, fmt_ptr2}); + int exit_code_int = 1; + llvm::Value *exit_code = llvm::ConstantInt::get(context, + llvm::APInt(32, exit_code_int)); + exit(context, module, *builder, exit_code); + }); + llvm::Value* item = llvm_utils->list_api->read_item(value_list, pos, + false, module, false); + return item; + } + void LLVMDict::_check_key_present_or_default(llvm::Module& module, llvm::Value *key, llvm::Value *key_list, ASR::ttype_t* key_asr_type, llvm::Value *value_list, llvm::Value *pos, llvm::Value *def_value, llvm::Value* &result) { @@ -1582,7 +1612,7 @@ namespace LCompilers { return result; } - llvm::Value* LLVMDictOptimizedLinearProbing::resolve_collision_for_read( + llvm::Value* LLVMDictOptimizedLinearProbing::resolve_collision_for_read_with_bound_check( llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, ASR::ttype_t* key_asr_type, ASR::ttype_t* /*value_asr_type*/) { @@ -1653,7 +1683,58 @@ namespace LCompilers { exit(context, module, *builder, exit_code); }); llvm::Value* item = llvm_utils->list_api->read_item(value_list, pos, - false, module, true);; + false, module, true); + return item; + } + + llvm::Value* LLVMDictOptimizedLinearProbing::resolve_collision_for_read( + llvm::Value* dict, llvm::Value* key_hash, + llvm::Value* key, llvm::Module& module, + ASR::ttype_t* key_asr_type, ASR::ttype_t* /*value_asr_type*/) { + llvm::Value* key_list = get_key_list(dict); + llvm::Value* value_list = get_value_list(dict); + llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); + llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); + if( !are_iterators_set ) { + pos_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + } + llvm::Function *fn = builder->GetInsertBlock()->getParent(); + llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); + llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); + llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); + llvm::Value* key_mask_value = LLVM::CreateLoad(*builder, + llvm_utils->create_ptr_gep(key_mask, key_hash)); + llvm::Value* is_prob_not_neeeded = builder->CreateICmpEQ(key_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))); + builder->CreateCondBr(is_prob_not_neeeded, thenBB, elseBB); + builder->SetInsertPoint(thenBB); + { + // A single by value comparison is needed even though + // we don't need to do linear probing. This is because + // the user can provide a key which is absent in the dict + // but is giving the same hash value as one of the keys present in the dict. + // In the above case we will end up returning value for a key + // which is not present in the dict. Instead we should return an error + // which is done in the below code. + llvm::Value* is_key_matching = llvm_utils->is_equal_by_value(key, + llvm_utils->list_api->read_item(key_list, key_hash, false, module, + LLVM::is_llvm_struct(key_asr_type)), module, key_asr_type); + + llvm_utils->create_if_else(is_key_matching, [=]() { + LLVM::CreateStore(*builder, key_hash, pos_ptr); + }, [=]() { + }); + } + builder->CreateBr(mergeBB); + llvm_utils->start_new_block(elseBB); + { + this->resolve_collision(capacity, key_hash, key, key_list, key_mask, + module, key_asr_type, true); + } + llvm_utils->start_new_block(mergeBB); + llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); + llvm::Value* item = llvm_utils->list_api->read_item(value_list, pos, + false, module, true); return item; } @@ -1709,6 +1790,36 @@ namespace LCompilers { } llvm::Value* LLVMDictSeparateChaining::resolve_collision_for_read( + llvm::Value* dict, llvm::Value* key_hash, + llvm::Value* key, llvm::Module& module, + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type) { + llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); + llvm::Value* key_value_pairs = LLVM::CreateLoad(*builder, get_pointer_to_key_value_pairs(dict)); + llvm::Value* key_value_pair_linked_list = llvm_utils->create_ptr_gep(key_value_pairs, key_hash); + llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); + llvm::Type* kv_struct_type = get_key_value_pair_type(key_asr_type, value_asr_type); + this->resolve_collision(capacity, key_hash, key, key_value_pair_linked_list, + kv_struct_type, key_mask, module, key_asr_type); + std::pair llvm_key = std::make_pair( + ASRUtils::get_type_code(key_asr_type), + ASRUtils::get_type_code(value_asr_type) + ); + llvm::Type* value_type = std::get<2>(typecode2dicttype[llvm_key]).second; + llvm::Value* tmp_value_ptr_local = nullptr; + if( !are_iterators_set ) { + tmp_value_ptr = builder->CreateAlloca(value_type, nullptr); + tmp_value_ptr_local = tmp_value_ptr; + } else { + tmp_value_ptr_local = builder->CreateBitCast(tmp_value_ptr, value_type->getPointerTo()); + } + llvm::Value* kv_struct_i8 = LLVM::CreateLoad(*builder, chain_itr); + llvm::Value* kv_struct = builder->CreateBitCast(kv_struct_i8, kv_struct_type->getPointerTo()); + llvm::Value* value = LLVM::CreateLoad(*builder, llvm_utils->create_gep(kv_struct, 1)); + LLVM::CreateStore(*builder, value, tmp_value_ptr_local); + return tmp_value_ptr; + } + + llvm::Value* LLVMDictSeparateChaining::resolve_collision_for_read_with_bound_check( llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type) { @@ -2233,12 +2344,18 @@ namespace LCompilers { } llvm::Value* LLVMDict::read_item(llvm::Value* dict, llvm::Value* key, - llvm::Module& module, ASR::Dict_t* dict_type, + llvm::Module& module, ASR::Dict_t* dict_type, bool enable_bounds_checking, bool get_pointer) { llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); llvm::Value* key_hash = get_key_hash(current_capacity, key, dict_type->m_key_type, module); - llvm::Value* value_ptr = this->resolve_collision_for_read(dict, key_hash, key, module, + llvm::Value* value_ptr; + if (enable_bounds_checking) { + value_ptr = this->resolve_collision_for_read_with_bound_check(dict, key_hash, key, module, dict_type->m_key_type, dict_type->m_value_type); + } else { + value_ptr = this->resolve_collision_for_read(dict, key_hash, key, module, + dict_type->m_key_type, dict_type->m_value_type); + } if( get_pointer ) { return value_ptr; } @@ -2260,11 +2377,17 @@ namespace LCompilers { } llvm::Value* LLVMDictSeparateChaining::read_item(llvm::Value* dict, llvm::Value* key, - llvm::Module& module, ASR::Dict_t* dict_type, bool get_pointer) { + llvm::Module& module, ASR::Dict_t* dict_type, bool enable_bounds_checking, bool get_pointer) { llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); llvm::Value* key_hash = get_key_hash(current_capacity, key, dict_type->m_key_type, module); - llvm::Value* value_ptr = this->resolve_collision_for_read(dict, key_hash, key, module, + llvm::Value* value_ptr; + if (enable_bounds_checking) { + value_ptr = this->resolve_collision_for_read_with_bound_check(dict, key_hash, key, module, dict_type->m_key_type, dict_type->m_value_type); + } else { + value_ptr = this->resolve_collision_for_read(dict, key_hash, key, module, + dict_type->m_key_type, dict_type->m_value_type); + } std::pair llvm_key = std::make_pair( ASRUtils::get_type_code(dict_type->m_key_type), ASRUtils::get_type_code(dict_type->m_value_type) diff --git a/src/libasr/codegen/llvm_utils.h b/src/libasr/codegen/llvm_utils.h index 81da377528..7967af1cb0 100644 --- a/src/libasr/codegen/llvm_utils.h +++ b/src/libasr/codegen/llvm_utils.h @@ -376,6 +376,11 @@ namespace LCompilers { llvm::Value* key, llvm::Module& module, ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type) = 0; + virtual + llvm::Value* resolve_collision_for_read_with_bound_check(llvm::Value* dict, llvm::Value* key_hash, + llvm::Value* key, llvm::Module& module, + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type) = 0; + virtual llvm::Value* resolve_collision_for_read_with_default(llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, @@ -401,7 +406,7 @@ namespace LCompilers { virtual llvm::Value* read_item(llvm::Value* dict, llvm::Value* key, - llvm::Module& module, ASR::Dict_t* dict_type, + llvm::Module& module, ASR::Dict_t* dict_type, bool enable_bounds_checking, bool get_pointer=false) = 0; virtual @@ -481,6 +486,10 @@ namespace LCompilers { llvm::Value* key, llvm::Module& module, ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type); + llvm::Value* resolve_collision_for_read_with_bound_check(llvm::Value* dict, llvm::Value* key_hash, + llvm::Value* key, llvm::Module& module, + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type); + llvm::Value* resolve_collision_for_read_with_default(llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, @@ -502,7 +511,7 @@ namespace LCompilers { std::map>& name2memidx); llvm::Value* read_item(llvm::Value* dict, llvm::Value* key, - llvm::Module& module, ASR::Dict_t* key_asr_type, + llvm::Module& module, ASR::Dict_t* key_asr_type, bool enable_bounds_checking, bool get_pointer=false); llvm::Value* get_item(llvm::Value* dict, llvm::Value* key, @@ -548,6 +557,10 @@ namespace LCompilers { llvm::Value* key, llvm::Module& module, ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type); + llvm::Value* resolve_collision_for_read_with_bound_check(llvm::Value* dict, llvm::Value* key_hash, + llvm::Value* key, llvm::Module& module, + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type); + llvm::Value* resolve_collision_for_read_with_default(llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, @@ -621,6 +634,10 @@ namespace LCompilers { llvm::Value* key, llvm::Module& module, ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type); + llvm::Value* resolve_collision_for_read_with_bound_check(llvm::Value* dict, llvm::Value* key_hash, + llvm::Value* key, llvm::Module& module, + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type); + llvm::Value* resolve_collision_for_read_with_default(llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, @@ -642,7 +659,7 @@ namespace LCompilers { std::map>& name2memidx); llvm::Value* read_item(llvm::Value* dict, llvm::Value* key, - llvm::Module& module, ASR::Dict_t* dict_type, + llvm::Module& module, ASR::Dict_t* dict_type, bool enable_bounds_checking, bool get_pointer=false); llvm::Value* get_item(llvm::Value* dict, llvm::Value* key,