@@ -1545,6 +1545,36 @@ namespace LCompilers {
15451545 return item;
15461546 }
15471547
1548+ llvm::Value* LLVMDict::resolve_collision_for_read_with_bound_check (
1549+ llvm::Value* dict, llvm::Value* key_hash,
1550+ llvm::Value* key, llvm::Module& module ,
1551+ ASR::ttype_t * key_asr_type, ASR::ttype_t * /* value_asr_type*/ ) {
1552+ llvm::Value* key_list = get_key_list (dict);
1553+ llvm::Value* value_list = get_value_list (dict);
1554+ llvm::Value* key_mask = LLVM::CreateLoad (*builder, get_pointer_to_keymask (dict));
1555+ llvm::Value* capacity = LLVM::CreateLoad (*builder, get_pointer_to_capacity (dict));
1556+ this ->resolve_collision (capacity, key_hash, key, key_list, key_mask, module , key_asr_type, true );
1557+ llvm::Value* pos = LLVM::CreateLoad (*builder, pos_ptr);
1558+ llvm::Value* is_key_matching = llvm_utils->is_equal_by_value (key,
1559+ llvm_utils->list_api ->read_item (key_list, pos, false , module ,
1560+ LLVM::is_llvm_struct (key_asr_type)), module , key_asr_type);
1561+
1562+ llvm_utils->create_if_else (is_key_matching, [&]() {
1563+ }, [&]() {
1564+ std::string message = " The dict does not contain the specified key" ;
1565+ llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr (" KeyError: %s\n " );
1566+ llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr (message);
1567+ print_error (context, module , *builder, {fmt_ptr, fmt_ptr2});
1568+ int exit_code_int = 1 ;
1569+ llvm::Value *exit_code = llvm::ConstantInt::get (context,
1570+ llvm::APInt (32 , exit_code_int));
1571+ exit (context, module , *builder, exit_code);
1572+ });
1573+ llvm::Value* item = llvm_utils->list_api ->read_item (value_list, pos,
1574+ false , module , false );
1575+ return item;
1576+ }
1577+
15481578 void LLVMDict::_check_key_present_or_default (llvm::Module& module , llvm::Value *key, llvm::Value *key_list,
15491579 ASR::ttype_t * key_asr_type, llvm::Value *value_list, llvm::Value *pos,
15501580 llvm::Value *def_value, llvm::Value* &result) {
@@ -1582,7 +1612,7 @@ namespace LCompilers {
15821612 return result;
15831613 }
15841614
1585- llvm::Value* LLVMDictOptimizedLinearProbing::resolve_collision_for_read (
1615+ llvm::Value* LLVMDictOptimizedLinearProbing::resolve_collision_for_read_with_bound_check (
15861616 llvm::Value* dict, llvm::Value* key_hash,
15871617 llvm::Value* key, llvm::Module& module ,
15881618 ASR::ttype_t * key_asr_type, ASR::ttype_t * /* value_asr_type*/ ) {
@@ -1653,7 +1683,58 @@ namespace LCompilers {
16531683 exit (context, module , *builder, exit_code);
16541684 });
16551685 llvm::Value* item = llvm_utils->list_api ->read_item (value_list, pos,
1656- false , module , true );;
1686+ false , module , true );
1687+ return item;
1688+ }
1689+
1690+ llvm::Value* LLVMDictOptimizedLinearProbing::resolve_collision_for_read (
1691+ llvm::Value* dict, llvm::Value* key_hash,
1692+ llvm::Value* key, llvm::Module& module ,
1693+ ASR::ttype_t * key_asr_type, ASR::ttype_t * /* value_asr_type*/ ) {
1694+ llvm::Value* key_list = get_key_list (dict);
1695+ llvm::Value* value_list = get_value_list (dict);
1696+ llvm::Value* key_mask = LLVM::CreateLoad (*builder, get_pointer_to_keymask (dict));
1697+ llvm::Value* capacity = LLVM::CreateLoad (*builder, get_pointer_to_capacity (dict));
1698+ if ( !are_iterators_set ) {
1699+ pos_ptr = builder->CreateAlloca (llvm::Type::getInt32Ty (context), nullptr );
1700+ }
1701+ llvm::Function *fn = builder->GetInsertBlock ()->getParent ();
1702+ llvm::BasicBlock *thenBB = llvm::BasicBlock::Create (context, " then" , fn);
1703+ llvm::BasicBlock *elseBB = llvm::BasicBlock::Create (context, " else" );
1704+ llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create (context, " ifcont" );
1705+ llvm::Value* key_mask_value = LLVM::CreateLoad (*builder,
1706+ llvm_utils->create_ptr_gep (key_mask, key_hash));
1707+ llvm::Value* is_prob_not_neeeded = builder->CreateICmpEQ (key_mask_value,
1708+ llvm::ConstantInt::get (llvm::Type::getInt8Ty (context), llvm::APInt (8 , 1 )));
1709+ builder->CreateCondBr (is_prob_not_neeeded, thenBB, elseBB);
1710+ builder->SetInsertPoint (thenBB);
1711+ {
1712+ // A single by value comparison is needed even though
1713+ // we don't need to do linear probing. This is because
1714+ // the user can provide a key which is absent in the dict
1715+ // but is giving the same hash value as one of the keys present in the dict.
1716+ // In the above case we will end up returning value for a key
1717+ // which is not present in the dict. Instead we should return an error
1718+ // which is done in the below code.
1719+ llvm::Value* is_key_matching = llvm_utils->is_equal_by_value (key,
1720+ llvm_utils->list_api ->read_item (key_list, key_hash, false , module ,
1721+ LLVM::is_llvm_struct (key_asr_type)), module , key_asr_type);
1722+
1723+ llvm_utils->create_if_else (is_key_matching, [=]() {
1724+ LLVM::CreateStore (*builder, key_hash, pos_ptr);
1725+ }, [=]() {
1726+ });
1727+ }
1728+ builder->CreateBr (mergeBB);
1729+ llvm_utils->start_new_block (elseBB);
1730+ {
1731+ this ->resolve_collision (capacity, key_hash, key, key_list, key_mask,
1732+ module , key_asr_type, true );
1733+ }
1734+ llvm_utils->start_new_block (mergeBB);
1735+ llvm::Value* pos = LLVM::CreateLoad (*builder, pos_ptr);
1736+ llvm::Value* item = llvm_utils->list_api ->read_item (value_list, pos,
1737+ false , module , true );
16571738 return item;
16581739 }
16591740
@@ -1709,6 +1790,36 @@ namespace LCompilers {
17091790 }
17101791
17111792 llvm::Value* LLVMDictSeparateChaining::resolve_collision_for_read (
1793+ llvm::Value* dict, llvm::Value* key_hash,
1794+ llvm::Value* key, llvm::Module& module ,
1795+ ASR::ttype_t * key_asr_type, ASR::ttype_t * value_asr_type) {
1796+ llvm::Value* capacity = LLVM::CreateLoad (*builder, get_pointer_to_capacity (dict));
1797+ llvm::Value* key_value_pairs = LLVM::CreateLoad (*builder, get_pointer_to_key_value_pairs (dict));
1798+ llvm::Value* key_value_pair_linked_list = llvm_utils->create_ptr_gep (key_value_pairs, key_hash);
1799+ llvm::Value* key_mask = LLVM::CreateLoad (*builder, get_pointer_to_keymask (dict));
1800+ llvm::Type* kv_struct_type = get_key_value_pair_type (key_asr_type, value_asr_type);
1801+ this ->resolve_collision (capacity, key_hash, key, key_value_pair_linked_list,
1802+ kv_struct_type, key_mask, module , key_asr_type);
1803+ std::pair<std::string, std::string> llvm_key = std::make_pair (
1804+ ASRUtils::get_type_code (key_asr_type),
1805+ ASRUtils::get_type_code (value_asr_type)
1806+ );
1807+ llvm::Type* value_type = std::get<2 >(typecode2dicttype[llvm_key]).second ;
1808+ llvm::Value* tmp_value_ptr_local = nullptr ;
1809+ if ( !are_iterators_set ) {
1810+ tmp_value_ptr = builder->CreateAlloca (value_type, nullptr );
1811+ tmp_value_ptr_local = tmp_value_ptr;
1812+ } else {
1813+ tmp_value_ptr_local = builder->CreateBitCast (tmp_value_ptr, value_type->getPointerTo ());
1814+ }
1815+ llvm::Value* kv_struct_i8 = LLVM::CreateLoad (*builder, chain_itr);
1816+ llvm::Value* kv_struct = builder->CreateBitCast (kv_struct_i8, kv_struct_type->getPointerTo ());
1817+ llvm::Value* value = LLVM::CreateLoad (*builder, llvm_utils->create_gep (kv_struct, 1 ));
1818+ LLVM::CreateStore (*builder, value, tmp_value_ptr_local);
1819+ return tmp_value_ptr;
1820+ }
1821+
1822+ llvm::Value* LLVMDictSeparateChaining::resolve_collision_for_read_with_bound_check (
17121823 llvm::Value* dict, llvm::Value* key_hash,
17131824 llvm::Value* key, llvm::Module& module ,
17141825 ASR::ttype_t * key_asr_type, ASR::ttype_t * value_asr_type) {
@@ -2233,12 +2344,18 @@ namespace LCompilers {
22332344 }
22342345
22352346 llvm::Value* LLVMDict::read_item (llvm::Value* dict, llvm::Value* key,
2236- llvm::Module& module , ASR::Dict_t* dict_type,
2347+ llvm::Module& module , ASR::Dict_t* dict_type, bool enable_bounds_checking,
22372348 bool get_pointer) {
22382349 llvm::Value* current_capacity = LLVM::CreateLoad (*builder, get_pointer_to_capacity (dict));
22392350 llvm::Value* key_hash = get_key_hash (current_capacity, key, dict_type->m_key_type , module );
2240- llvm::Value* value_ptr = this ->resolve_collision_for_read (dict, key_hash, key, module ,
2351+ llvm::Value* value_ptr;
2352+ if (enable_bounds_checking) {
2353+ value_ptr = this ->resolve_collision_for_read_with_bound_check (dict, key_hash, key, module ,
22412354 dict_type->m_key_type , dict_type->m_value_type );
2355+ } else {
2356+ value_ptr = this ->resolve_collision_for_read (dict, key_hash, key, module ,
2357+ dict_type->m_key_type , dict_type->m_value_type );
2358+ }
22422359 if ( get_pointer ) {
22432360 return value_ptr;
22442361 }
@@ -2260,11 +2377,17 @@ namespace LCompilers {
22602377 }
22612378
22622379 llvm::Value* LLVMDictSeparateChaining::read_item (llvm::Value* dict, llvm::Value* key,
2263- llvm::Module& module , ASR::Dict_t* dict_type, bool get_pointer) {
2380+ llvm::Module& module , ASR::Dict_t* dict_type, bool enable_bounds_checking, bool get_pointer) {
22642381 llvm::Value* current_capacity = LLVM::CreateLoad (*builder, get_pointer_to_capacity (dict));
22652382 llvm::Value* key_hash = get_key_hash (current_capacity, key, dict_type->m_key_type , module );
2266- llvm::Value* value_ptr = this ->resolve_collision_for_read (dict, key_hash, key, module ,
2383+ llvm::Value* value_ptr;
2384+ if (enable_bounds_checking) {
2385+ value_ptr = this ->resolve_collision_for_read_with_bound_check (dict, key_hash, key, module ,
22672386 dict_type->m_key_type , dict_type->m_value_type );
2387+ } else {
2388+ value_ptr = this ->resolve_collision_for_read (dict, key_hash, key, module ,
2389+ dict_type->m_key_type , dict_type->m_value_type );
2390+ }
22682391 std::pair<std::string, std::string> llvm_key = std::make_pair (
22692392 ASRUtils::get_type_code (dict_type->m_key_type ),
22702393 ASRUtils::get_type_code (dict_type->m_value_type )
0 commit comments