Skip to content

Commit d67e28d

Browse files
committed
Add option for bound checking in dict
1 parent 46cde59 commit d67e28d

File tree

3 files changed

+150
-9
lines changed

3 files changed

+150
-9
lines changed

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1778,6 +1778,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
17781778
} else {
17791779
set_dict_api(dict_type);
17801780
tmp = llvm_utils->dict_api->read_item(pdict, key, *module, dict_type,
1781+
compiler_options.enable_bounds_checking,
17811782
LLVM::is_llvm_struct(dict_type->m_value_type));
17821783
}
17831784
}

src/libasr/codegen/llvm_utils.cpp

Lines changed: 129 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1545,6 +1545,36 @@ namespace LCompilers {
15451545
return item;
15461546
}
15471547

1548+
// Emits IR for a checked dict read; LLVMDict::read_item selects this variant
// when enable_bounds_checking is true.
//
// After resolve_collision runs, the position is loaded from pos_ptr and the
// key stored at that position is compared by value with `key`. On a mismatch
// the generated code reports a KeyError through print_error and then calls
// exit with code 1; on a match it falls through to the value read below.
//
// dict         - dict whose key list, value list, key mask and capacity are read
// key_hash     - hash forwarded to resolve_collision
// key          - key being looked up
// module       - module forwarded to the IR-emitting helpers
// key_asr_type - ASR type of the key, used for the by-value comparison
// (the value ASR type parameter is unused here)
//
// Returns the result of list_api->read_item on the value list at the resolved
// position, requested in pointer form (see the note at the final call).
llvm::Value* LLVMDict::resolve_collision_for_read_with_bound_check(
1549+
llvm::Value* dict, llvm::Value* key_hash,
1550+
llvm::Value* key, llvm::Module& module,
1551+
ASR::ttype_t* key_asr_type, ASR::ttype_t* /*value_asr_type*/) {
1552+
llvm::Value* key_list = get_key_list(dict);
1553+
llvm::Value* value_list = get_value_list(dict);
1554+
llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict));
1555+
llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict));
1556+
// NOTE(review): resolve_collision presumably leaves the probed position in
// pos_ptr, since it is loaded right after the call - confirm in its definition.
this->resolve_collision(capacity, key_hash, key, key_list, key_mask, module, key_asr_type, true);
1557+
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
1558+
// Compare the requested key with whatever key sits at the resolved position.
llvm::Value* is_key_matching = llvm_utils->is_equal_by_value(key,
1559+
llvm_utils->list_api->read_item(key_list, pos, false, module,
1560+
LLVM::is_llvm_struct(key_asr_type)), module, key_asr_type);
1561+
1562+
// The "match" branch is intentionally empty; only the mismatch branch emits code.
llvm_utils->create_if_else(is_key_matching, [&]() {
1563+
}, [&]() {
1564+
std::string message = "The dict does not contain the specified key";
1565+
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n");
1566+
llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message);
1567+
print_error(context, module, *builder, {fmt_ptr, fmt_ptr2});
1568+
int exit_code_int = 1;
1569+
llvm::Value *exit_code = llvm::ConstantInt::get(context,
1570+
llvm::APInt(32, exit_code_int));
1571+
exit(context, module, *builder, exit_code);
1572+
});
1573+
// The last argument is `true`, matching the other resolve_collision_for_read*
// implementations in this file: LLVMDict::read_item stores the result as
// `value_ptr` and returns it unchanged when a pointer is requested, so this
// function has to hand back the slot in pointer form rather than a loaded value.
llvm::Value* item = llvm_utils->list_api->read_item(value_list, pos,
1574+
false, module, true);
1575+
return item;
1576+
}
1577+
15481578
void LLVMDict::_check_key_present_or_default(llvm::Module& module, llvm::Value *key, llvm::Value *key_list,
15491579
ASR::ttype_t* key_asr_type, llvm::Value *value_list, llvm::Value *pos,
15501580
llvm::Value *def_value, llvm::Value* &result) {
@@ -1582,7 +1612,7 @@ namespace LCompilers {
15821612
return result;
15831613
}
15841614

1585-
llvm::Value* LLVMDictOptimizedLinearProbing::resolve_collision_for_read(
1615+
llvm::Value* LLVMDictOptimizedLinearProbing::resolve_collision_for_read_with_bound_check(
15861616
llvm::Value* dict, llvm::Value* key_hash,
15871617
llvm::Value* key, llvm::Module& module,
15881618
ASR::ttype_t* key_asr_type, ASR::ttype_t* /*value_asr_type*/) {
@@ -1653,7 +1683,58 @@ namespace LCompilers {
16531683
exit(context, module, *builder, exit_code);
16541684
});
16551685
llvm::Value* item = llvm_utils->list_api->read_item(value_list, pos,
1656-
false, module, true);;
1686+
false, module, true);
1687+
return item;
1688+
}
1689+
1690+
// Emits IR for an unchecked dict read; no KeyError path is generated here.
//
// The key-mask entry at key_hash is tested first. If it equals 1 the slot at
// key_hash is compared directly with `key` and, on a match, key_hash is stored
// into pos_ptr. Otherwise resolve_collision is called to probe. In both cases
// the position is then loaded from pos_ptr and used to read the value list.
//
// dict         - dict whose key list, value list, key mask and capacity are read
// key_hash     - hash of `key`, also used as the first candidate position
// key          - key being looked up
// module       - module forwarded to the IR-emitting helpers
// key_asr_type - ASR type of the key, used for the by-value comparison
// (the value ASR type parameter is unused here)
//
// Returns the result of list_api->read_item on the value list at the loaded
// position, called with `true` as its last argument.
llvm::Value* LLVMDictOptimizedLinearProbing::resolve_collision_for_read(
1691+
llvm::Value* dict, llvm::Value* key_hash,
1692+
llvm::Value* key, llvm::Module& module,
1693+
ASR::ttype_t* key_asr_type, ASR::ttype_t* /*value_asr_type*/) {
1694+
llvm::Value* key_list = get_key_list(dict);
1695+
llvm::Value* value_list = get_value_list(dict);
1696+
llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict));
1697+
llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict));
1698+
// A fresh i32 slot for the position is allocated unless iterators were set up
// beforehand (in which case the existing pos_ptr is reused).
if( !are_iterators_set ) {
1699+
pos_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr);
1700+
}
1701+
llvm::Function *fn = builder->GetInsertBlock()->getParent();
1702+
llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn);
1703+
llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else");
1704+
llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont");
1705+
llvm::Value* key_mask_value = LLVM::CreateLoad(*builder,
1706+
llvm_utils->create_ptr_gep(key_mask, key_hash));
1707+
// A mask value of 1 at key_hash is treated as "no probing needed".
llvm::Value* is_prob_not_neeeded = builder->CreateICmpEQ(key_mask_value,
1708+
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1)));
1709+
builder->CreateCondBr(is_prob_not_neeeded, thenBB, elseBB);
1710+
builder->SetInsertPoint(thenBB);
1711+
{
1712+
// A single by-value comparison is still emitted even though
1713+
// no linear probing is done on this path. This is because
1714+
// the user can provide a key which is absent in the dict
1715+
// but gives the same hash value as one of the keys present in the dict.
1716+
// Unlike the bounds-checked variant, no error is emitted on a mismatch:
1717+
// the position is only stored into pos_ptr when the keys match, and the
1718+
// mismatch branch below is intentionally empty.
1719+
llvm::Value* is_key_matching = llvm_utils->is_equal_by_value(key,
1720+
llvm_utils->list_api->read_item(key_list, key_hash, false, module,
1721+
LLVM::is_llvm_struct(key_asr_type)), module, key_asr_type);
1722+
1723+
llvm_utils->create_if_else(is_key_matching, [=]() {
1724+
LLVM::CreateStore(*builder, key_hash, pos_ptr);
1725+
}, [=]() {
1726+
});
1727+
}
1728+
builder->CreateBr(mergeBB);
1729+
llvm_utils->start_new_block(elseBB);
1730+
{
1731+
this->resolve_collision(capacity, key_hash, key, key_list, key_mask,
1732+
module, key_asr_type, true);
1733+
}
1734+
llvm_utils->start_new_block(mergeBB);
1735+
// NOTE(review): on the "then" path with a key mismatch nothing in this
// function has written pos_ptr before this load - confirm that is acceptable
// for the unchecked mode.
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
1736+
llvm::Value* item = llvm_utils->list_api->read_item(value_list, pos,
1737+
false, module, true);
16571738
return item;
16581739
}
16591740

@@ -1709,6 +1790,36 @@ namespace LCompilers {
17091790
}
17101791

17111792
// Emits IR for an unchecked read from a separate-chaining dict; no KeyError
// path is generated in this function.
//
// The bucket at key_hash is passed to resolve_collision, then the key-value
// struct referenced by chain_itr is loaded and its field at index 1 is copied
// into a temporary value slot.
//
// dict           - dict whose capacity, key-value pairs and key mask are read
// key_hash       - index of the bucket to search
// key            - key being looked up
// module         - module forwarded to resolve_collision
// key_asr_type   - ASR type of the key
// value_asr_type - ASR type of the value; with the key type it selects the
//                  key-value struct type and the LLVM value type
//
// Returns tmp_value_ptr (the member, not the possibly bitcast local copy).
llvm::Value* LLVMDictSeparateChaining::resolve_collision_for_read(
1793+
llvm::Value* dict, llvm::Value* key_hash,
1794+
llvm::Value* key, llvm::Module& module,
1795+
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type) {
1796+
llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict));
1797+
llvm::Value* key_value_pairs = LLVM::CreateLoad(*builder, get_pointer_to_key_value_pairs(dict));
1798+
// Address of the bucket entry selected by key_hash.
llvm::Value* key_value_pair_linked_list = llvm_utils->create_ptr_gep(key_value_pairs, key_hash);
1799+
llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict));
1800+
llvm::Type* kv_struct_type = get_key_value_pair_type(key_asr_type, value_asr_type);
1801+
// NOTE(review): resolve_collision presumably leaves the matching chain node
// in chain_itr, which is loaded further down - confirm in its definition.
this->resolve_collision(capacity, key_hash, key, key_value_pair_linked_list,
1802+
kv_struct_type, key_mask, module, key_asr_type);
1803+
// Look up the LLVM value type registered for this (key, value) type-code pair.
std::pair<std::string, std::string> llvm_key = std::make_pair(
1804+
ASRUtils::get_type_code(key_asr_type),
1805+
ASRUtils::get_type_code(value_asr_type)
1806+
);
1807+
llvm::Type* value_type = std::get<2>(typecode2dicttype[llvm_key]).second;
1808+
llvm::Value* tmp_value_ptr_local = nullptr;
1809+
// Without pre-set iterators a new slot of value_type is allocated; otherwise
// the existing tmp_value_ptr is reused through a cast to value_type*.
if( !are_iterators_set ) {
1810+
tmp_value_ptr = builder->CreateAlloca(value_type, nullptr);
1811+
tmp_value_ptr_local = tmp_value_ptr;
1812+
} else {
1813+
tmp_value_ptr_local = builder->CreateBitCast(tmp_value_ptr, value_type->getPointerTo());
1814+
}
1815+
llvm::Value* kv_struct_i8 = LLVM::CreateLoad(*builder, chain_itr);
1816+
llvm::Value* kv_struct = builder->CreateBitCast(kv_struct_i8, kv_struct_type->getPointerTo());
1817+
// Field index 1 of the key-value struct is loaded as the value.
llvm::Value* value = LLVM::CreateLoad(*builder, llvm_utils->create_gep(kv_struct, 1));
1818+
LLVM::CreateStore(*builder, value, tmp_value_ptr_local);
1819+
return tmp_value_ptr;
1820+
}
1821+
1822+
llvm::Value* LLVMDictSeparateChaining::resolve_collision_for_read_with_bound_check(
17121823
llvm::Value* dict, llvm::Value* key_hash,
17131824
llvm::Value* key, llvm::Module& module,
17141825
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type) {
@@ -2233,12 +2344,18 @@ namespace LCompilers {
22332344
}
22342345

22352346
llvm::Value* LLVMDict::read_item(llvm::Value* dict, llvm::Value* key,
2236-
llvm::Module& module, ASR::Dict_t* dict_type,
2347+
llvm::Module& module, ASR::Dict_t* dict_type, bool enable_bounds_checking,
22372348
bool get_pointer) {
22382349
llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict));
22392350
llvm::Value* key_hash = get_key_hash(current_capacity, key, dict_type->m_key_type, module);
2240-
llvm::Value* value_ptr = this->resolve_collision_for_read(dict, key_hash, key, module,
2351+
llvm::Value* value_ptr;
2352+
if (enable_bounds_checking) {
2353+
value_ptr = this->resolve_collision_for_read_with_bound_check(dict, key_hash, key, module,
22412354
dict_type->m_key_type, dict_type->m_value_type);
2355+
} else {
2356+
value_ptr = this->resolve_collision_for_read(dict, key_hash, key, module,
2357+
dict_type->m_key_type, dict_type->m_value_type);
2358+
}
22422359
if( get_pointer ) {
22432360
return value_ptr;
22442361
}
@@ -2260,11 +2377,17 @@ namespace LCompilers {
22602377
}
22612378

22622379
llvm::Value* LLVMDictSeparateChaining::read_item(llvm::Value* dict, llvm::Value* key,
2263-
llvm::Module& module, ASR::Dict_t* dict_type, bool get_pointer) {
2380+
llvm::Module& module, ASR::Dict_t* dict_type, bool enable_bounds_checking, bool get_pointer) {
22642381
llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict));
22652382
llvm::Value* key_hash = get_key_hash(current_capacity, key, dict_type->m_key_type, module);
2266-
llvm::Value* value_ptr = this->resolve_collision_for_read(dict, key_hash, key, module,
2383+
llvm::Value* value_ptr;
2384+
if (enable_bounds_checking) {
2385+
value_ptr = this->resolve_collision_for_read_with_bound_check(dict, key_hash, key, module,
22672386
dict_type->m_key_type, dict_type->m_value_type);
2387+
} else {
2388+
value_ptr = this->resolve_collision_for_read(dict, key_hash, key, module,
2389+
dict_type->m_key_type, dict_type->m_value_type);
2390+
}
22682391
std::pair<std::string, std::string> llvm_key = std::make_pair(
22692392
ASRUtils::get_type_code(dict_type->m_key_type),
22702393
ASRUtils::get_type_code(dict_type->m_value_type)

src/libasr/codegen/llvm_utils.h

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,11 @@ namespace LCompilers {
376376
llvm::Value* key, llvm::Module& module,
377377
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type) = 0;
378378

379+
// Checked counterpart of resolve_collision_for_read. The read_item
// implementations call it instead of resolve_collision_for_read when
// enable_bounds_checking is true. Pure virtual: each dict implementation
// supplies its own definition.
virtual
380+
llvm::Value* resolve_collision_for_read_with_bound_check(llvm::Value* dict, llvm::Value* key_hash,
381+
llvm::Value* key, llvm::Module& module,
382+
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type) = 0;
383+
379384
virtual
380385
llvm::Value* resolve_collision_for_read_with_default(llvm::Value* dict, llvm::Value* key_hash,
381386
llvm::Value* key, llvm::Module& module,
@@ -401,7 +406,7 @@ namespace LCompilers {
401406

402407
virtual
403408
llvm::Value* read_item(llvm::Value* dict, llvm::Value* key,
404-
llvm::Module& module, ASR::Dict_t* dict_type,
409+
llvm::Module& module, ASR::Dict_t* dict_type, bool enable_bounds_checking,
405410
bool get_pointer=false) = 0;
406411

407412
virtual
@@ -481,6 +486,10 @@ namespace LCompilers {
481486
llvm::Value* key, llvm::Module& module,
482487
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type);
483488

489+
// Checked read-side collision resolver; read_item calls it instead of
// resolve_collision_for_read when enable_bounds_checking is true.
llvm::Value* resolve_collision_for_read_with_bound_check(llvm::Value* dict, llvm::Value* key_hash,
490+
llvm::Value* key, llvm::Module& module,
491+
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type);
492+
484493
llvm::Value* resolve_collision_for_read_with_default(llvm::Value* dict, llvm::Value* key_hash,
485494
llvm::Value* key, llvm::Module& module,
486495
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type,
@@ -502,7 +511,7 @@ namespace LCompilers {
502511
std::map<std::string, std::map<std::string, int>>& name2memidx);
503512

504513
llvm::Value* read_item(llvm::Value* dict, llvm::Value* key,
505-
llvm::Module& module, ASR::Dict_t* key_asr_type,
514+
llvm::Module& module, ASR::Dict_t* key_asr_type, bool enable_bounds_checking,
506515
bool get_pointer=false);
507516

508517
llvm::Value* get_item(llvm::Value* dict, llvm::Value* key,
@@ -548,6 +557,10 @@ namespace LCompilers {
548557
llvm::Value* key, llvm::Module& module,
549558
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type);
550559

560+
// Checked read-side collision resolver; read_item calls it instead of
// resolve_collision_for_read when enable_bounds_checking is true.
llvm::Value* resolve_collision_for_read_with_bound_check(llvm::Value* dict, llvm::Value* key_hash,
561+
llvm::Value* key, llvm::Module& module,
562+
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type);
563+
551564
llvm::Value* resolve_collision_for_read_with_default(llvm::Value* dict, llvm::Value* key_hash,
552565
llvm::Value* key, llvm::Module& module,
553566
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type,
@@ -621,6 +634,10 @@ namespace LCompilers {
621634
llvm::Value* key, llvm::Module& module,
622635
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type);
623636

637+
// Checked read-side collision resolver; read_item calls it instead of
// resolve_collision_for_read when enable_bounds_checking is true.
llvm::Value* resolve_collision_for_read_with_bound_check(llvm::Value* dict, llvm::Value* key_hash,
638+
llvm::Value* key, llvm::Module& module,
639+
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type);
640+
624641
llvm::Value* resolve_collision_for_read_with_default(llvm::Value* dict, llvm::Value* key_hash,
625642
llvm::Value* key, llvm::Module& module,
626643
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type,
@@ -642,7 +659,7 @@ namespace LCompilers {
642659
std::map<std::string, std::map<std::string, int>>& name2memidx);
643660

644661
llvm::Value* read_item(llvm::Value* dict, llvm::Value* key,
645-
llvm::Module& module, ASR::Dict_t* dict_type,
662+
llvm::Module& module, ASR::Dict_t* dict_type, bool enable_bounds_checking,
646663
bool get_pointer=false);
647664

648665
llvm::Value* get_item(llvm::Value* dict, llvm::Value* key,

0 commit comments

Comments
 (0)