Skip to content

Commit c935c96

Browse files
committed
Use default value in LLVM and enable tests
1 parent ce5248a commit c935c96

File tree

4 files changed

+76
-28
lines changed

4 files changed

+76
-28
lines changed

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ RUN(NAME test_dict_07 LABELS cpython llvm)
295295
RUN(NAME test_dict_08 LABELS cpython llvm c)
296296
RUN(NAME test_dict_09 LABELS cpython llvm c)
297297
RUN(NAME test_dict_10 LABELS cpython llvm) # TODO: Add support of dict with string in C backend
298-
RUN(NAME test_dict_11 LABELS cpython c) # TODO: Add LLVM support
298+
RUN(NAME test_dict_11 LABELS cpython llvm c)
299299
RUN(NAME test_for_loop LABELS cpython llvm c)
300300
RUN(NAME modules_01 LABELS cpython llvm c wasm wasm_x86 wasm_x64)
301301
RUN(NAME modules_02 LABELS cpython llvm c wasm wasm_x86 wasm_x64)

integration_tests/test_dict_11.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,13 @@ def test_dict_11():
44
num : dict[i32, i32]
55
num = {11: 22, 33: 44, 55: 66}
66
assert num.get(7, -1) == -1
7+
assert num.get(11, -1) == 22
8+
assert num.get(33, -1) == 44
9+
assert num.get(55, -1) == 66
10+
assert num.get(72, -110) == -110
11+
d : dict[i32, str]
12+
d = {1: "1", 2: "22", 3: "333"}
13+
assert d.get(2, "00") == "22"
14+
assert d.get(21, "nokey") == "nokey"
715

816
test_dict_11()

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1793,9 +1793,18 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
17931793
this->visit_expr_wrapper(x.m_key, true);
17941794
ptr_loads = ptr_loads_copy;
17951795
llvm::Value *key = tmp;
1796+
llvm::Type *val_type = get_type_from_ttype_t_util(dict_type->m_value_type);
1797+
llvm::Value *def_value_ptr = nullptr;
1798+
if (x.m_default) {
1799+
def_value_ptr = builder->CreateAlloca(val_type, nullptr);
1800+
ptr_loads = !LLVM::is_llvm_struct(dict_type->m_value_type);
1801+
this->visit_expr_wrapper(x.m_default, true);
1802+
ptr_loads = ptr_loads_copy;
1803+
builder->CreateStore(tmp, def_value_ptr);
1804+
}
17961805

17971806
set_dict_api(dict_type);
1798-
tmp = llvm_utils->dict_api->read_item(pdict, key, *module, dict_type,
1807+
tmp = llvm_utils->dict_api->read_item(pdict, key, *module, dict_type, def_value_ptr,
17991808
LLVM::is_llvm_struct(dict_type->m_value_type));
18001809
}
18011810

src/libasr/codegen/llvm_utils.cpp

Lines changed: 57 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1548,7 +1548,7 @@ namespace LCompilers {
15481548
llvm::Value* LLVMDictOptimizedLinearProbing::resolve_collision_for_read(
15491549
llvm::Value* dict, llvm::Value* key_hash,
15501550
llvm::Value* key, llvm::Module& module,
1551-
ASR::ttype_t* key_asr_type, ASR::ttype_t* /*value_asr_type*/, llvm::Value *def_value) {
1551+
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, llvm::Value *def_value) {
15521552
llvm::Value* key_list = get_key_list(dict);
15531553
llvm::Value* value_list = get_value_list(dict);
15541554
llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict));
@@ -1564,6 +1564,12 @@ namespace LCompilers {
15641564
llvm_utils->create_ptr_gep(key_mask, key_hash));
15651565
llvm::Value* is_prob_not_neeeded = builder->CreateICmpEQ(key_mask_value,
15661566
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1)));
1567+
std::pair<std::string, std::string> llvm_key = std::make_pair(
1568+
ASRUtils::get_type_code(key_asr_type),
1569+
ASRUtils::get_type_code(value_asr_type)
1570+
);
1571+
llvm::Type* value_type = std::get<2>(typecode2dicttype[llvm_key]).second;
1572+
llvm::Value* result = builder->CreateAlloca(value_type, nullptr);
15671573
builder->CreateCondBr(is_prob_not_neeeded, thenBB, elseBB);
15681574
builder->SetInsertPoint(thenBB);
15691575
{
@@ -1587,32 +1593,57 @@ namespace LCompilers {
15871593
builder->CreateBr(mergeBB_single_match);
15881594
llvm_utils->start_new_block(elseBB_single_match);
15891595
{
1590-
std::string message = "The dict does not contain the specified key";
1591-
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n");
1592-
llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message);
1593-
print_error(context, module, *builder, {fmt_ptr, fmt_ptr2});
1594-
int exit_code_int = 1;
1595-
llvm::Value *exit_code = llvm::ConstantInt::get(context,
1596-
llvm::APInt(32, exit_code_int));
1597-
exit(context, module, *builder, exit_code);
1596+
if (def_value != nullptr) {
1597+
LLVM::CreateStore(*builder, LLVM::CreateLoad(*builder, def_value), result);
1598+
} else {
1599+
std::string message = "The dict does not contain the specified key";
1600+
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n");
1601+
llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message);
1602+
print_error(context, module, *builder, {fmt_ptr, fmt_ptr2});
1603+
int exit_code_int = 1;
1604+
llvm::Value *exit_code = llvm::ConstantInt::get(context,
1605+
llvm::APInt(32, exit_code_int));
1606+
exit(context, module, *builder, exit_code);
1607+
}
15981608
}
15991609
llvm_utils->start_new_block(mergeBB_single_match);
16001610
}
16011611
builder->CreateBr(mergeBB);
16021612
llvm_utils->start_new_block(elseBB);
16031613
{
1604-
if (def_value != nullptr) {
1605-
llvm_utils->start_new_block(mergeBB);
1606-
return def_value;
1607-
}
16081614
this->resolve_collision(capacity, key_hash, key, key_list, key_mask,
16091615
module, key_asr_type, true);
16101616
}
16111617
llvm_utils->start_new_block(mergeBB);
16121618
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
1613-
llvm::Value* item = llvm_utils->list_api->read_item(value_list, pos,
1614-
false, module, true);
1615-
return item;
1619+
1620+
if (def_value != nullptr) {
1621+
llvm::Function *fn_single_match = builder->GetInsertBlock()->getParent();
1622+
llvm::BasicBlock *thenBB_single_match = llvm::BasicBlock::Create(context, "then", fn_single_match);
1623+
llvm::BasicBlock *elseBB_single_match = llvm::BasicBlock::Create(context, "else");
1624+
llvm::BasicBlock *mergeBB_single_match = llvm::BasicBlock::Create(context, "ifcont");
1625+
llvm::Value* is_key_matching = llvm_utils->is_equal_by_value(key,
1626+
llvm_utils->list_api->read_item(key_list, pos, false, module,
1627+
LLVM::is_llvm_struct(key_asr_type)), module, key_asr_type);
1628+
builder->CreateCondBr(is_key_matching, thenBB_single_match, elseBB_single_match);
1629+
builder->SetInsertPoint(thenBB_single_match);
1630+
{
1631+
llvm::Value* item = llvm_utils->list_api->read_item(value_list, pos,
1632+
false, module, false);
1633+
LLVM::CreateStore(*builder, item, result);
1634+
}
1635+
builder->CreateBr(mergeBB_single_match);
1636+
llvm_utils->start_new_block(elseBB_single_match);
1637+
{
1638+
LLVM::CreateStore(*builder, LLVM::CreateLoad(*builder, def_value), result);
1639+
}
1640+
llvm_utils->start_new_block(mergeBB_single_match);
1641+
} else {
1642+
llvm::Value* item = llvm_utils->list_api->read_item(value_list, pos,
1643+
false, module, false);
1644+
LLVM::CreateStore(*builder, item, result);
1645+
}
1646+
return result;
16161647
}
16171648

16181649
llvm::Value* LLVMDictSeparateChaining::resolve_collision_for_read(
@@ -1663,17 +1694,17 @@ namespace LCompilers {
16631694
llvm_utils->start_new_block(elseBB_single_match);
16641695
{
16651696
if (def_value != nullptr) {
1666-
llvm_utils->start_new_block(mergeBB_single_match);
1667-
return def_value;
1697+
LLVM::CreateStore(*builder, LLVM::CreateLoad(*builder, def_value), tmp_value_ptr_local);
1698+
} else {
1699+
std::string message = "The dict does not contain the specified key";
1700+
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n");
1701+
llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message);
1702+
print_error(context, module, *builder, {fmt_ptr, fmt_ptr2});
1703+
int exit_code_int = 1;
1704+
llvm::Value *exit_code = llvm::ConstantInt::get(context,
1705+
llvm::APInt(32, exit_code_int));
1706+
exit(context, module, *builder, exit_code);
16681707
}
1669-
std::string message = "The dict does not contain the specified key";
1670-
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n");
1671-
llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message);
1672-
print_error(context, module, *builder, {fmt_ptr, fmt_ptr2});
1673-
int exit_code_int = 1;
1674-
llvm::Value *exit_code = llvm::ConstantInt::get(context,
1675-
llvm::APInt(32, exit_code_int));
1676-
exit(context, module, *builder, exit_code);
16771708
}
16781709
llvm_utils->start_new_block(mergeBB_single_match);
16791710
return tmp_value_ptr;

0 commit comments

Comments
 (0)