Skip to content

Commit 6249f49

Browse files
committed
Use default value in LLVM and enable tests
1 parent ce5248a commit 6249f49

File tree

4 files changed

+57
-28
lines changed

4 files changed

+57
-28
lines changed

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ RUN(NAME test_dict_07 LABELS cpython llvm)
295295
RUN(NAME test_dict_08 LABELS cpython llvm c)
296296
RUN(NAME test_dict_09 LABELS cpython llvm c)
297297
RUN(NAME test_dict_10 LABELS cpython llvm) # TODO: Add support of dict with string in C backend
298-
RUN(NAME test_dict_11 LABELS cpython c) # TODO: Add LLVM support
298+
RUN(NAME test_dict_11 LABELS cpython llvm c)
299299
RUN(NAME test_for_loop LABELS cpython llvm c)
300300
RUN(NAME modules_01 LABELS cpython llvm c wasm wasm_x86 wasm_x64)
301301
RUN(NAME modules_02 LABELS cpython llvm c wasm wasm_x86 wasm_x64)

integration_tests/test_dict_11.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,13 @@ def test_dict_11():
44
num : dict[i32, i32]
55
num = {11: 22, 33: 44, 55: 66}
66
assert num.get(7, -1) == -1
7+
assert num.get(11, -1) == 22
8+
assert num.get(33, -1) == 44
9+
assert num.get(55, -1) == 66
10+
assert num.get(72, -110) == -110
11+
d : dict[i32, str]
12+
d = {1: "1", 2: "22", 3: "333"}
13+
assert d.get(2, "00") == "22"
14+
assert d.get(21, "nokey") == "nokey"
715

816
test_dict_11()

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1793,9 +1793,18 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
17931793
this->visit_expr_wrapper(x.m_key, true);
17941794
ptr_loads = ptr_loads_copy;
17951795
llvm::Value *key = tmp;
1796+
llvm::Type *val_type = get_type_from_ttype_t_util(dict_type->m_value_type);
1797+
llvm::Value *def_value_ptr = nullptr;
1798+
if (x.m_default) {
1799+
def_value_ptr = builder->CreateAlloca(val_type, nullptr);
1800+
ptr_loads = !LLVM::is_llvm_struct(dict_type->m_value_type);
1801+
this->visit_expr_wrapper(x.m_default, true);
1802+
ptr_loads = ptr_loads_copy;
1803+
builder->CreateStore(tmp, def_value_ptr);
1804+
}
17961805

17971806
set_dict_api(dict_type);
1798-
tmp = llvm_utils->dict_api->read_item(pdict, key, *module, dict_type,
1807+
tmp = llvm_utils->dict_api->read_item(pdict, key, *module, dict_type, def_value_ptr,
17991808
LLVM::is_llvm_struct(dict_type->m_value_type));
18001809
}
18011810

src/libasr/codegen/llvm_utils.cpp

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1548,7 +1548,7 @@ namespace LCompilers {
15481548
llvm::Value* LLVMDictOptimizedLinearProbing::resolve_collision_for_read(
15491549
llvm::Value* dict, llvm::Value* key_hash,
15501550
llvm::Value* key, llvm::Module& module,
1551-
ASR::ttype_t* key_asr_type, ASR::ttype_t* /*value_asr_type*/, llvm::Value *def_value) {
1551+
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, llvm::Value *def_value) {
15521552
llvm::Value* key_list = get_key_list(dict);
15531553
llvm::Value* value_list = get_value_list(dict);
15541554
llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict));
@@ -1564,6 +1564,12 @@ namespace LCompilers {
15641564
llvm_utils->create_ptr_gep(key_mask, key_hash));
15651565
llvm::Value* is_prob_not_neeeded = builder->CreateICmpEQ(key_mask_value,
15661566
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1)));
1567+
std::pair<std::string, std::string> llvm_key = std::make_pair(
1568+
ASRUtils::get_type_code(key_asr_type),
1569+
ASRUtils::get_type_code(value_asr_type)
1570+
);
1571+
llvm::Type* value_type = std::get<2>(typecode2dicttype[llvm_key]).second;
1572+
llvm::Value* result = builder->CreateAlloca(value_type, nullptr);
15671573
builder->CreateCondBr(is_prob_not_neeeded, thenBB, elseBB);
15681574
builder->SetInsertPoint(thenBB);
15691575
{
@@ -1583,36 +1589,42 @@ namespace LCompilers {
15831589
LLVM::is_llvm_struct(key_asr_type)), module, key_asr_type);
15841590
builder->CreateCondBr(is_key_matching, thenBB_single_match, elseBB_single_match);
15851591
builder->SetInsertPoint(thenBB_single_match);
1586-
LLVM::CreateStore(*builder, key_hash, pos_ptr);
1592+
{
1593+
LLVM::CreateStore(*builder, key_hash, pos_ptr);
1594+
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
1595+
llvm::Value* item = llvm_utils->list_api->read_item(value_list, pos,
1596+
false, module, false);
1597+
LLVM::CreateStore(*builder, item, result);
1598+
}
15871599
builder->CreateBr(mergeBB_single_match);
15881600
llvm_utils->start_new_block(elseBB_single_match);
15891601
{
1590-
std::string message = "The dict does not contain the specified key";
1591-
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n");
1592-
llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message);
1593-
print_error(context, module, *builder, {fmt_ptr, fmt_ptr2});
1594-
int exit_code_int = 1;
1595-
llvm::Value *exit_code = llvm::ConstantInt::get(context,
1596-
llvm::APInt(32, exit_code_int));
1597-
exit(context, module, *builder, exit_code);
1602+
if (def_value != nullptr) {
1603+
LLVM::CreateStore(*builder, LLVM::CreateLoad(*builder, def_value), result);
1604+
} else {
1605+
std::string message = "The dict does not contain the specified key";
1606+
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n");
1607+
llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message);
1608+
print_error(context, module, *builder, {fmt_ptr, fmt_ptr2});
1609+
int exit_code_int = 1;
1610+
llvm::Value *exit_code = llvm::ConstantInt::get(context,
1611+
llvm::APInt(32, exit_code_int));
1612+
exit(context, module, *builder, exit_code);
1613+
}
15981614
}
15991615
llvm_utils->start_new_block(mergeBB_single_match);
16001616
}
16011617
builder->CreateBr(mergeBB);
16021618
llvm_utils->start_new_block(elseBB);
16031619
{
16041620
if (def_value != nullptr) {
1605-
llvm_utils->start_new_block(mergeBB);
1606-
return def_value;
1621+
LLVM::CreateStore(*builder, LLVM::CreateLoad(*builder, def_value), result);
16071622
}
16081623
this->resolve_collision(capacity, key_hash, key, key_list, key_mask,
16091624
module, key_asr_type, true);
16101625
}
16111626
llvm_utils->start_new_block(mergeBB);
1612-
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
1613-
llvm::Value* item = llvm_utils->list_api->read_item(value_list, pos,
1614-
false, module, true);
1615-
return item;
1627+
return result;
16161628
}
16171629

16181630
llvm::Value* LLVMDictSeparateChaining::resolve_collision_for_read(
@@ -1663,17 +1675,17 @@ namespace LCompilers {
16631675
llvm_utils->start_new_block(elseBB_single_match);
16641676
{
16651677
if (def_value != nullptr) {
1666-
llvm_utils->start_new_block(mergeBB_single_match);
1667-
return def_value;
1678+
LLVM::CreateStore(*builder, LLVM::CreateLoad(*builder, def_value), tmp_value_ptr_local);
1679+
} else {
1680+
std::string message = "The dict does not contain the specified key";
1681+
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n");
1682+
llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message);
1683+
print_error(context, module, *builder, {fmt_ptr, fmt_ptr2});
1684+
int exit_code_int = 1;
1685+
llvm::Value *exit_code = llvm::ConstantInt::get(context,
1686+
llvm::APInt(32, exit_code_int));
1687+
exit(context, module, *builder, exit_code);
16681688
}
1669-
std::string message = "The dict does not contain the specified key";
1670-
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n");
1671-
llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message);
1672-
print_error(context, module, *builder, {fmt_ptr, fmt_ptr2});
1673-
int exit_code_int = 1;
1674-
llvm::Value *exit_code = llvm::ConstantInt::get(context,
1675-
llvm::APInt(32, exit_code_int));
1676-
exit(context, module, *builder, exit_code);
16771689
}
16781690
llvm_utils->start_new_block(mergeBB_single_match);
16791691
return tmp_value_ptr;

0 commit comments

Comments
 (0)