Skip to content

Commit ce6cae0

Browse files
authored
C/LLVM: Support for dict.get (#1700)
1 parent 7e6c1b9 commit ce6cae0

11 files changed

+503
-76
lines changed

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,7 @@ RUN(NAME test_dict_07 LABELS cpython llvm)
361361
RUN(NAME test_dict_08 LABELS cpython llvm c)
362362
RUN(NAME test_dict_09 LABELS cpython llvm c)
363363
RUN(NAME test_dict_10 LABELS cpython llvm) # TODO: Add support of dict with string in C backend
364+
RUN(NAME test_dict_11 LABELS cpython llvm c)
364365
RUN(NAME test_for_loop LABELS cpython llvm c)
365366
RUN(NAME modules_01 LABELS cpython llvm c wasm wasm_x86 wasm_x64)
366367
RUN(NAME modules_02 LABELS cpython llvm c wasm wasm_x86 wasm_x64)

integration_tests/test_dict_11.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from lpython import i32
2+
3+
def test_dict_11():
4+
num : dict[i32, i32]
5+
num = {11: 22, 33: 44, 55: 66}
6+
assert num.get(7, -1) == -1
7+
assert num.get(11, -1) == 22
8+
assert num.get(33, -1) == 44
9+
assert num.get(55, -1) == 66
10+
assert num.get(72, -110) == -110
11+
d : dict[i32, str]
12+
d = {1: "1", 2: "22", 3: "333"}
13+
assert d.get(2, "00") == "22"
14+
assert d.get(21, "nokey") == "nokey"
15+
16+
test_dict_11()

src/libasr/codegen/asr_to_c_cpp.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,15 +1065,22 @@ R"(#include <stdio.h>
10651065
void visit_DictItem(const ASR::DictItem_t& x) {
10661066
ASR::Dict_t* dict_type = ASR::down_cast<ASR::Dict_t>(
10671067
ASRUtils::expr_type(x.m_a));
1068-
std::string dict_get_fun = c_ds_api->get_dict_get_func(dict_type);
1069-
10701068
this->visit_expr(*x.m_a);
10711069
std::string d_var = std::move(src);
10721070

10731071
this->visit_expr(*x.m_key);
10741072
std::string k = std::move(src);
10751073

1076-
src = dict_get_fun + "(&" + d_var + ", " + k + ")";
1074+
if (x.m_default) {
1075+
this->visit_expr(*x.m_default);
1076+
std::string def_value = std::move(src);
1077+
std::string dict_get_fun = c_ds_api->get_dict_get_func(dict_type,
1078+
true);
1079+
src = dict_get_fun + "(&" + d_var + ", " + k + ", " + def_value + ")";
1080+
} else {
1081+
std::string dict_get_fun = c_ds_api->get_dict_get_func(dict_type);
1082+
src = dict_get_fun + "(&" + d_var + ", " + k + ")";
1083+
}
10771084
}
10781085

10791086
void visit_ListAppend(const ASR::ListAppend_t& x) {

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 21 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -281,34 +281,6 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
281281
builder->SetInsertPoint(bb);
282282
}
283283

284-
// Note: `create_if_else` and `create_loop` are optional APIs
285-
// that do not have to be used. Many times, for more complicated
286-
// things, it might be more readable to just use the LLVM API
287-
// without any extra layer on top. In some other cases, it might
288-
// be more readable to use this abstraction.
289-
// The `if_block` and `else_block` must generate one or more blocks. In
290-
// addition, the `if_block` must not be terminated, we terminate it
291-
// ourselves. The `else_block` can be either terminated or not.
292-
template <typename IF, typename ELSE>
293-
void create_if_else(llvm::Value * cond, IF if_block, ELSE else_block) {
294-
llvm::Function *fn = builder->GetInsertBlock()->getParent();
295-
296-
llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn);
297-
llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else");
298-
llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont");
299-
300-
builder->CreateCondBr(cond, thenBB, elseBB);
301-
builder->SetInsertPoint(thenBB); {
302-
if_block();
303-
}
304-
builder->CreateBr(mergeBB);
305-
306-
start_new_block(elseBB); {
307-
else_block();
308-
}
309-
start_new_block(mergeBB);
310-
}
311-
312284
template <typename Cond, typename Body>
313285
void create_loop(char *name, Cond condition, Body loop_body) {
314286
dict_api_lp->set_iterators();
@@ -1487,7 +1459,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
14871459
}
14881460
}
14891461
llvm::Value *cond = arr_descr->get_is_allocated_flag(tmp);
1490-
create_if_else(cond, [=]() {
1462+
llvm_utils->create_if_else(cond, [=]() {
14911463
call_lfortran_free(free_fn);
14921464
}, [](){});
14931465
}
@@ -1793,10 +1765,22 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
17931765
this->visit_expr_wrapper(x.m_key, true);
17941766
ptr_loads = ptr_loads_copy;
17951767
llvm::Value *key = tmp;
1796-
1797-
set_dict_api(dict_type);
1798-
tmp = llvm_utils->dict_api->read_item(pdict, key, *module, dict_type,
1768+
if (x.m_default) {
1769+
llvm::Type *val_type = get_type_from_ttype_t_util(dict_type->m_value_type);
1770+
llvm::Value *def_value_ptr = builder->CreateAlloca(val_type, nullptr);
1771+
ptr_loads = !LLVM::is_llvm_struct(dict_type->m_value_type);
1772+
this->visit_expr_wrapper(x.m_default, true);
1773+
ptr_loads = ptr_loads_copy;
1774+
builder->CreateStore(tmp, def_value_ptr);
1775+
set_dict_api(dict_type);
1776+
tmp = llvm_utils->dict_api->get_item(pdict, key, *module, dict_type, def_value_ptr,
17991777
LLVM::is_llvm_struct(dict_type->m_value_type));
1778+
} else {
1779+
set_dict_api(dict_type);
1780+
tmp = llvm_utils->dict_api->read_item(pdict, key, *module, dict_type,
1781+
compiler_options.enable_bounds_checking,
1782+
LLVM::is_llvm_struct(dict_type->m_value_type));
1783+
}
18001784
}
18011785

18021786
void visit_DictPop(const ASR::DictPop_t& x) {
@@ -5140,7 +5124,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
51405124

51415125
void visit_If(const ASR::If_t &x) {
51425126
this->visit_expr_wrapper(x.m_test, true);
5143-
create_if_else(tmp, [=]() {
5127+
llvm_utils->create_if_else(tmp, [=]() {
51445128
for (size_t i=0; i<x.n_body; i++) {
51455129
this->visit_stmt(*x.m_body[i]);
51465130
}
@@ -5157,7 +5141,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
51575141
llvm::Value *cond = tmp;
51585142
llvm::Value *then_val = nullptr;
51595143
llvm::Value *else_val = nullptr;
5160-
create_if_else(cond, [=, &then_val]() {
5144+
llvm_utils->create_if_else(cond, [=, &then_val]() {
51615145
this->visit_expr_wrapper(x.m_body, true);
51625146
then_val = tmp;
51635147
}, [=, &else_val]() {
@@ -5313,7 +5297,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
53135297
}
53145298
switch (x.m_op) {
53155299
case ASR::logicalbinopType::And: {
5316-
create_if_else(cond, [&, result, left_val]() {
5300+
llvm_utils->create_if_else(cond, [&, result, left_val]() {
53175301
LLVM::CreateStore(*builder, left_val, result);
53185302
}, [&, result, right_val]() {
53195303
LLVM::CreateStore(*builder, right_val, result);
@@ -5322,7 +5306,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
53225306
break;
53235307
};
53245308
case ASR::logicalbinopType::Or: {
5325-
create_if_else(cond, [&, result, right_val]() {
5309+
llvm_utils->create_if_else(cond, [&, result, right_val]() {
53265310
LLVM::CreateStore(*builder, right_val, result);
53275311

53285312
}, [&, result, left_val]() {
@@ -5854,7 +5838,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
58545838
void visit_Assert(const ASR::Assert_t &x) {
58555839
if (compiler_options.emit_debug_info) debug_emit_loc(x);
58565840
this->visit_expr_wrapper(x.m_test, true);
5857-
create_if_else(tmp, []() {}, [=]() {
5841+
llvm_utils->create_if_else(tmp, []() {}, [=]() {
58585842
if (compiler_options.emit_debug_info) {
58595843
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr(infile);
58605844
llvm::Value *fmt_ptr1 = llvm::ConstantInt::get(context, llvm::APInt(

src/libasr/codegen/c_utils.h

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,8 +1126,11 @@ class CCPPDSUtils {
11261126
return typecodeToDSfuncs[dict_type_code]["dict_insert"];
11271127
}
11281128

1129-
std::string get_dict_get_func(ASR::Dict_t* d_type) {
1129+
std::string get_dict_get_func(ASR::Dict_t* d_type, bool with_fallback=false) {
11301130
std::string dict_type_code = ASRUtils::get_type_code((ASR::ttype_t*)d_type, true);
1131+
if (with_fallback) {
1132+
return typecodeToDSfuncs[dict_type_code]["dict_get_fb"];
1133+
}
11311134
return typecodeToDSfuncs[dict_type_code]["dict_get"];
11321135
}
11331136

@@ -1177,6 +1180,7 @@ class CCPPDSUtils {
11771180
dict_resize(dict_type, dict_struct_type, dict_type_code);
11781181
dict_insert(dict_type, dict_struct_type, dict_type_code);
11791182
dict_get_item(dict_type, dict_struct_type, dict_type_code);
1183+
dict_get_item_with_fallback(dict_type, dict_struct_type, dict_type_code);
11801184
dict_len(dict_type, dict_struct_type, dict_type_code);
11811185
dict_pop(dict_type, dict_struct_type, dict_type_code);
11821186
dict_deepcopy(dict_type, dict_struct_type, dict_type_code);
@@ -1294,11 +1298,31 @@ class CCPPDSUtils {
12941298
generated_code += indent + tab + "int j=k\%x->capacity, c = 0;\n";
12951299
generated_code += indent + tab + "while(c<x->capacity && x->present[j] && !(x->key[j] == k)) j=(j+1)\%x->capacity, c++;\n";
12961300
generated_code += indent + tab + "if (x->present[j] && x->key[j] == k) return x->value[j];\n";
1297-
generated_code += indent + tab + "printf(\"Key not found\");\n";
1301+
generated_code += indent + tab + "printf(\"Key not found\\n\");\n";
12981302
generated_code += indent + tab + "exit(1);\n";
12991303
generated_code += indent + "}\n\n";
13001304
}
13011305

1306+
void dict_get_item_with_fallback(ASR::Dict_t *dict_type, std::string dict_struct_type,
1307+
std::string dict_type_code) {
1308+
std::string indent(indentation_level * indentation_spaces, ' ');
1309+
std::string tab(indentation_spaces, ' ');
1310+
std::string dict_get_func = global_scope->get_unique_name("dict_get_item_fb_" + dict_type_code);
1311+
typecodeToDSfuncs[dict_type_code]["dict_get_fb"] = dict_get_func;
1312+
std::string key = CUtils::get_c_type_from_ttype_t(dict_type->m_key_type);
1313+
std::string val = CUtils::get_c_type_from_ttype_t(dict_type->m_value_type);
1314+
std::string signature = val + " " + dict_get_func + "(" + dict_struct_type + "* x, " +\
1315+
key + " k, " + val + " dv)";
1316+
func_decls += indent + "inline " + signature + ";\n";
1317+
signature = indent + signature;
1318+
generated_code += indent + signature + " {\n";
1319+
generated_code += indent + tab + "int j=k\%x->capacity, c = 0;\n";
1320+
generated_code += indent + tab + "while(c<x->capacity && x->present[j] && !(x->key[j] == k)) j=(j+1)\%x->capacity, c++;\n";
1321+
generated_code += indent + tab + "if (x->present[j] && x->key[j] == k) return x->value[j];\n";
1322+
generated_code += indent + tab + "return dv;\n";
1323+
generated_code += indent + "}\n\n";
1324+
}
1325+
13021326
void dict_len(ASR::Dict_t *dict_type, std::string dict_struct_type,
13031327
std::string dict_type_code) {
13041328
std::string indent(indentation_level * indentation_spaces, ' ');

0 commit comments

Comments
 (0)