@@ -1545,6 +1545,36 @@ namespace LCompilers {
1545
1545
return item;
1546
1546
}
1547
1547
1548
+ llvm::Value* LLVMDict::resolve_collision_for_read_with_bound_check (
1549
+ llvm::Value* dict, llvm::Value* key_hash,
1550
+ llvm::Value* key, llvm::Module& module,
1551
+ ASR::ttype_t * key_asr_type, ASR::ttype_t * /* value_asr_type*/ ) {
1552
+ llvm::Value* key_list = get_key_list (dict);
1553
+ llvm::Value* value_list = get_value_list (dict);
1554
+ llvm::Value* key_mask = LLVM::CreateLoad (*builder, get_pointer_to_keymask (dict));
1555
+ llvm::Value* capacity = LLVM::CreateLoad (*builder, get_pointer_to_capacity (dict));
1556
+ this ->resolve_collision (capacity, key_hash, key, key_list, key_mask, module, key_asr_type, true );
1557
+ llvm::Value* pos = LLVM::CreateLoad (*builder, pos_ptr);
1558
+ llvm::Value* is_key_matching = llvm_utils->is_equal_by_value (key,
1559
+ llvm_utils->list_api ->read_item (key_list, pos, false , module,
1560
+ LLVM::is_llvm_struct (key_asr_type)), module, key_asr_type);
1561
+
1562
+ llvm_utils->create_if_else (is_key_matching, [&]() {
1563
+ }, [&]() {
1564
+ std::string message = " The dict does not contain the specified key" ;
1565
+ llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr (" KeyError: %s\n " );
1566
+ llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr (message);
1567
+ print_error (context, module, *builder, {fmt_ptr, fmt_ptr2});
1568
+ int exit_code_int = 1 ;
1569
+ llvm::Value *exit_code = llvm::ConstantInt::get (context,
1570
+ llvm::APInt (32 , exit_code_int));
1571
+ exit (context, module, *builder, exit_code);
1572
+ });
1573
+ llvm::Value* item = llvm_utils->list_api ->read_item (value_list, pos,
1574
+ false , module, false );
1575
+ return item;
1576
+ }
1577
+
1548
1578
void LLVMDict::_check_key_present_or_default (llvm::Module& module, llvm::Value *key, llvm::Value *key_list,
1549
1579
ASR::ttype_t * key_asr_type, llvm::Value *value_list, llvm::Value *pos,
1550
1580
llvm::Value *def_value, llvm::Value* &result) {
@@ -1582,7 +1612,7 @@ namespace LCompilers {
1582
1612
return result;
1583
1613
}
1584
1614
1585
- llvm::Value* LLVMDictOptimizedLinearProbing::resolve_collision_for_read (
1615
+ llvm::Value* LLVMDictOptimizedLinearProbing::resolve_collision_for_read_with_bound_check (
1586
1616
llvm::Value* dict, llvm::Value* key_hash,
1587
1617
llvm::Value* key, llvm::Module& module,
1588
1618
ASR::ttype_t * key_asr_type, ASR::ttype_t * /* value_asr_type*/ ) {
@@ -1653,7 +1683,58 @@ namespace LCompilers {
1653
1683
exit (context, module, *builder, exit_code);
1654
1684
});
1655
1685
llvm::Value* item = llvm_utils->list_api ->read_item (value_list, pos,
1656
- false , module, true );;
1686
+ false , module, true );
1687
+ return item;
1688
+ }
1689
+
1690
+ llvm::Value* LLVMDictOptimizedLinearProbing::resolve_collision_for_read (
1691
+ llvm::Value* dict, llvm::Value* key_hash,
1692
+ llvm::Value* key, llvm::Module& module,
1693
+ ASR::ttype_t * key_asr_type, ASR::ttype_t * /* value_asr_type*/ ) {
1694
+ llvm::Value* key_list = get_key_list (dict);
1695
+ llvm::Value* value_list = get_value_list (dict);
1696
+ llvm::Value* key_mask = LLVM::CreateLoad (*builder, get_pointer_to_keymask (dict));
1697
+ llvm::Value* capacity = LLVM::CreateLoad (*builder, get_pointer_to_capacity (dict));
1698
+ if ( !are_iterators_set ) {
1699
+ pos_ptr = builder->CreateAlloca (llvm::Type::getInt32Ty (context), nullptr );
1700
+ }
1701
+ llvm::Function *fn = builder->GetInsertBlock ()->getParent ();
1702
+ llvm::BasicBlock *thenBB = llvm::BasicBlock::Create (context, " then" , fn);
1703
+ llvm::BasicBlock *elseBB = llvm::BasicBlock::Create (context, " else" );
1704
+ llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create (context, " ifcont" );
1705
+ llvm::Value* key_mask_value = LLVM::CreateLoad (*builder,
1706
+ llvm_utils->create_ptr_gep (key_mask, key_hash));
1707
+ llvm::Value* is_prob_not_neeeded = builder->CreateICmpEQ (key_mask_value,
1708
+ llvm::ConstantInt::get (llvm::Type::getInt8Ty (context), llvm::APInt (8 , 1 )));
1709
+ builder->CreateCondBr (is_prob_not_neeeded, thenBB, elseBB);
1710
+ builder->SetInsertPoint (thenBB);
1711
+ {
1712
+ // A single by value comparison is needed even though
1713
+ // we don't need to do linear probing. This is because
1714
+ // the user can provide a key which is absent in the dict
1715
+ // but is giving the same hash value as one of the keys present in the dict.
1716
+ // In the above case we will end up returning value for a key
1717
+ // which is not present in the dict. Instead we should return an error
1718
+ // which is done in the below code.
1719
+ llvm::Value* is_key_matching = llvm_utils->is_equal_by_value (key,
1720
+ llvm_utils->list_api ->read_item (key_list, key_hash, false , module,
1721
+ LLVM::is_llvm_struct (key_asr_type)), module, key_asr_type);
1722
+
1723
+ llvm_utils->create_if_else (is_key_matching, [=]() {
1724
+ LLVM::CreateStore (*builder, key_hash, pos_ptr);
1725
+ }, [=]() {
1726
+ });
1727
+ }
1728
+ builder->CreateBr (mergeBB);
1729
+ llvm_utils->start_new_block (elseBB);
1730
+ {
1731
+ this ->resolve_collision (capacity, key_hash, key, key_list, key_mask,
1732
+ module, key_asr_type, true );
1733
+ }
1734
+ llvm_utils->start_new_block (mergeBB);
1735
+ llvm::Value* pos = LLVM::CreateLoad (*builder, pos_ptr);
1736
+ llvm::Value* item = llvm_utils->list_api ->read_item (value_list, pos,
1737
+ false , module, true );
1657
1738
return item;
1658
1739
}
1659
1740
@@ -1709,6 +1790,36 @@ namespace LCompilers {
1709
1790
}
1710
1791
1711
1792
llvm::Value* LLVMDictSeparateChaining::resolve_collision_for_read (
1793
+ llvm::Value* dict, llvm::Value* key_hash,
1794
+ llvm::Value* key, llvm::Module& module,
1795
+ ASR::ttype_t * key_asr_type, ASR::ttype_t * value_asr_type) {
1796
+ llvm::Value* capacity = LLVM::CreateLoad (*builder, get_pointer_to_capacity (dict));
1797
+ llvm::Value* key_value_pairs = LLVM::CreateLoad (*builder, get_pointer_to_key_value_pairs (dict));
1798
+ llvm::Value* key_value_pair_linked_list = llvm_utils->create_ptr_gep (key_value_pairs, key_hash);
1799
+ llvm::Value* key_mask = LLVM::CreateLoad (*builder, get_pointer_to_keymask (dict));
1800
+ llvm::Type* kv_struct_type = get_key_value_pair_type (key_asr_type, value_asr_type);
1801
+ this ->resolve_collision (capacity, key_hash, key, key_value_pair_linked_list,
1802
+ kv_struct_type, key_mask, module, key_asr_type);
1803
+ std::pair<std::string, std::string> llvm_key = std::make_pair (
1804
+ ASRUtils::get_type_code (key_asr_type),
1805
+ ASRUtils::get_type_code (value_asr_type)
1806
+ );
1807
+ llvm::Type* value_type = std::get<2 >(typecode2dicttype[llvm_key]).second ;
1808
+ llvm::Value* tmp_value_ptr_local = nullptr ;
1809
+ if ( !are_iterators_set ) {
1810
+ tmp_value_ptr = builder->CreateAlloca (value_type, nullptr );
1811
+ tmp_value_ptr_local = tmp_value_ptr;
1812
+ } else {
1813
+ tmp_value_ptr_local = builder->CreateBitCast (tmp_value_ptr, value_type->getPointerTo ());
1814
+ }
1815
+ llvm::Value* kv_struct_i8 = LLVM::CreateLoad (*builder, chain_itr);
1816
+ llvm::Value* kv_struct = builder->CreateBitCast (kv_struct_i8, kv_struct_type->getPointerTo ());
1817
+ llvm::Value* value = LLVM::CreateLoad (*builder, llvm_utils->create_gep (kv_struct, 1 ));
1818
+ LLVM::CreateStore (*builder, value, tmp_value_ptr_local);
1819
+ return tmp_value_ptr;
1820
+ }
1821
+
1822
+ llvm::Value* LLVMDictSeparateChaining::resolve_collision_for_read_with_bound_check (
1712
1823
llvm::Value* dict, llvm::Value* key_hash,
1713
1824
llvm::Value* key, llvm::Module& module,
1714
1825
ASR::ttype_t * key_asr_type, ASR::ttype_t * value_asr_type) {
@@ -2233,12 +2344,18 @@ namespace LCompilers {
2233
2344
}
2234
2345
2235
2346
llvm::Value* LLVMDict::read_item (llvm::Value* dict, llvm::Value* key,
2236
- llvm::Module& module, ASR::Dict_t* dict_type,
2347
+ llvm::Module& module, ASR::Dict_t* dict_type, bool enable_bounds_checking,
2237
2348
bool get_pointer) {
2238
2349
llvm::Value* current_capacity = LLVM::CreateLoad (*builder, get_pointer_to_capacity (dict));
2239
2350
llvm::Value* key_hash = get_key_hash (current_capacity, key, dict_type->m_key_type , module);
2240
- llvm::Value* value_ptr = this ->resolve_collision_for_read (dict, key_hash, key, module,
2351
+ llvm::Value* value_ptr;
2352
+ if (enable_bounds_checking) {
2353
+ value_ptr = this ->resolve_collision_for_read_with_bound_check (dict, key_hash, key, module,
2241
2354
dict_type->m_key_type , dict_type->m_value_type );
2355
+ } else {
2356
+ value_ptr = this ->resolve_collision_for_read (dict, key_hash, key, module,
2357
+ dict_type->m_key_type , dict_type->m_value_type );
2358
+ }
2242
2359
if ( get_pointer ) {
2243
2360
return value_ptr;
2244
2361
}
@@ -2260,11 +2377,17 @@ namespace LCompilers {
2260
2377
}
2261
2378
2262
2379
llvm::Value* LLVMDictSeparateChaining::read_item (llvm::Value* dict, llvm::Value* key,
2263
- llvm::Module& module, ASR::Dict_t* dict_type, bool get_pointer) {
2380
+ llvm::Module& module, ASR::Dict_t* dict_type, bool enable_bounds_checking, bool get_pointer) {
2264
2381
llvm::Value* current_capacity = LLVM::CreateLoad (*builder, get_pointer_to_capacity (dict));
2265
2382
llvm::Value* key_hash = get_key_hash (current_capacity, key, dict_type->m_key_type , module);
2266
- llvm::Value* value_ptr = this ->resolve_collision_for_read (dict, key_hash, key, module,
2383
+ llvm::Value* value_ptr;
2384
+ if (enable_bounds_checking) {
2385
+ value_ptr = this ->resolve_collision_for_read_with_bound_check (dict, key_hash, key, module,
2267
2386
dict_type->m_key_type , dict_type->m_value_type );
2387
+ } else {
2388
+ value_ptr = this ->resolve_collision_for_read (dict, key_hash, key, module,
2389
+ dict_type->m_key_type , dict_type->m_value_type );
2390
+ }
2268
2391
std::pair<std::string, std::string> llvm_key = std::make_pair (
2269
2392
ASRUtils::get_type_code (dict_type->m_key_type ),
2270
2393
ASRUtils::get_type_code (dict_type->m_value_type )
0 commit comments