Skip to content

C/LLVM: Fixes with dict get #1700

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ RUN(NAME test_dict_07 LABELS cpython llvm)
RUN(NAME test_dict_08 LABELS cpython llvm c)
RUN(NAME test_dict_09 LABELS cpython llvm c)
RUN(NAME test_dict_10 LABELS cpython llvm) # TODO: Add support of dict with string in C backend
RUN(NAME test_dict_11 LABELS cpython llvm c)
RUN(NAME test_for_loop LABELS cpython llvm c)
RUN(NAME modules_01 LABELS cpython llvm c wasm wasm_x86 wasm_x64)
RUN(NAME modules_02 LABELS cpython llvm c wasm wasm_x86 wasm_x64)
Expand Down
16 changes: 16 additions & 0 deletions integration_tests/test_dict_11.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from lpython import i32

def test_dict_11():
num : dict[i32, i32]
num = {11: 22, 33: 44, 55: 66}
assert num.get(7, -1) == -1
assert num.get(11, -1) == 22
assert num.get(33, -1) == 44
assert num.get(55, -1) == 66
assert num.get(72, -110) == -110
d : dict[i32, str]
d = {1: "1", 2: "22", 3: "333"}
assert d.get(2, "00") == "22"
assert d.get(21, "nokey") == "nokey"

test_dict_11()
13 changes: 10 additions & 3 deletions src/libasr/codegen/asr_to_c_cpp.h
Original file line number Diff line number Diff line change
Expand Up @@ -1065,15 +1065,22 @@ R"(#include <stdio.h>
void visit_DictItem(const ASR::DictItem_t& x) {
ASR::Dict_t* dict_type = ASR::down_cast<ASR::Dict_t>(
ASRUtils::expr_type(x.m_a));
std::string dict_get_fun = c_ds_api->get_dict_get_func(dict_type);

this->visit_expr(*x.m_a);
std::string d_var = std::move(src);

this->visit_expr(*x.m_key);
std::string k = std::move(src);

src = dict_get_fun + "(&" + d_var + ", " + k + ")";
if (x.m_default) {
this->visit_expr(*x.m_default);
std::string def_value = std::move(src);
std::string dict_get_fun = c_ds_api->get_dict_get_func(dict_type,
true);
src = dict_get_fun + "(&" + d_var + ", " + k + ", " + def_value + ")";
} else {
std::string dict_get_fun = c_ds_api->get_dict_get_func(dict_type);
src = dict_get_fun + "(&" + d_var + ", " + k + ")";
}
}

void visit_ListAppend(const ASR::ListAppend_t& x) {
Expand Down
58 changes: 21 additions & 37 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,34 +281,6 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
builder->SetInsertPoint(bb);
}

// Note: `create_if_else` and `create_loop` are optional APIs
// that do not have to be used. Many times, for more complicated
// things, it might be more readable to just use the LLVM API
// without any extra layer on top. In some other cases, it might
// be more readable to use this abstraction.
// The `if_block` and `else_block` must generate one or more blocks. In
// addition, the `if_block` must not be terminated, we terminate it
// ourselves. The `else_block` can be either terminated or not.
template <typename IF, typename ELSE>
void create_if_else(llvm::Value * cond, IF if_block, ELSE else_block) {
llvm::Function *fn = builder->GetInsertBlock()->getParent();

llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn);
llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else");
llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont");

builder->CreateCondBr(cond, thenBB, elseBB);
builder->SetInsertPoint(thenBB); {
if_block();
}
builder->CreateBr(mergeBB);

start_new_block(elseBB); {
else_block();
}
start_new_block(mergeBB);
}

template <typename Cond, typename Body>
void create_loop(char *name, Cond condition, Body loop_body) {
dict_api_lp->set_iterators();
Expand Down Expand Up @@ -1487,7 +1459,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
}
}
llvm::Value *cond = arr_descr->get_is_allocated_flag(tmp);
create_if_else(cond, [=]() {
llvm_utils->create_if_else(cond, [=]() {
call_lfortran_free(free_fn);
}, [](){});
}
Expand Down Expand Up @@ -1793,10 +1765,22 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
this->visit_expr_wrapper(x.m_key, true);
ptr_loads = ptr_loads_copy;
llvm::Value *key = tmp;

set_dict_api(dict_type);
tmp = llvm_utils->dict_api->read_item(pdict, key, *module, dict_type,
if (x.m_default) {
llvm::Type *val_type = get_type_from_ttype_t_util(dict_type->m_value_type);
llvm::Value *def_value_ptr = builder->CreateAlloca(val_type, nullptr);
ptr_loads = !LLVM::is_llvm_struct(dict_type->m_value_type);
this->visit_expr_wrapper(x.m_default, true);
ptr_loads = ptr_loads_copy;
builder->CreateStore(tmp, def_value_ptr);
set_dict_api(dict_type);
tmp = llvm_utils->dict_api->get_item(pdict, key, *module, dict_type, def_value_ptr,
LLVM::is_llvm_struct(dict_type->m_value_type));
} else {
set_dict_api(dict_type);
tmp = llvm_utils->dict_api->read_item(pdict, key, *module, dict_type,
compiler_options.enable_bounds_checking,
LLVM::is_llvm_struct(dict_type->m_value_type));
}
}

void visit_DictPop(const ASR::DictPop_t& x) {
Expand Down Expand Up @@ -5140,7 +5124,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>

void visit_If(const ASR::If_t &x) {
this->visit_expr_wrapper(x.m_test, true);
create_if_else(tmp, [=]() {
llvm_utils->create_if_else(tmp, [=]() {
for (size_t i=0; i<x.n_body; i++) {
this->visit_stmt(*x.m_body[i]);
}
Expand All @@ -5157,7 +5141,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
llvm::Value *cond = tmp;
llvm::Value *then_val = nullptr;
llvm::Value *else_val = nullptr;
create_if_else(cond, [=, &then_val]() {
llvm_utils->create_if_else(cond, [=, &then_val]() {
this->visit_expr_wrapper(x.m_body, true);
then_val = tmp;
}, [=, &else_val]() {
Expand Down Expand Up @@ -5313,7 +5297,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
}
switch (x.m_op) {
case ASR::logicalbinopType::And: {
create_if_else(cond, [&, result, left_val]() {
llvm_utils->create_if_else(cond, [&, result, left_val]() {
LLVM::CreateStore(*builder, left_val, result);
}, [&, result, right_val]() {
LLVM::CreateStore(*builder, right_val, result);
Expand All @@ -5322,7 +5306,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
break;
};
case ASR::logicalbinopType::Or: {
create_if_else(cond, [&, result, right_val]() {
llvm_utils->create_if_else(cond, [&, result, right_val]() {
LLVM::CreateStore(*builder, right_val, result);

}, [&, result, left_val]() {
Expand Down Expand Up @@ -5854,7 +5838,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
void visit_Assert(const ASR::Assert_t &x) {
if (compiler_options.emit_debug_info) debug_emit_loc(x);
this->visit_expr_wrapper(x.m_test, true);
create_if_else(tmp, []() {}, [=]() {
llvm_utils->create_if_else(tmp, []() {}, [=]() {
if (compiler_options.emit_debug_info) {
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr(infile);
llvm::Value *fmt_ptr1 = llvm::ConstantInt::get(context, llvm::APInt(
Expand Down
28 changes: 26 additions & 2 deletions src/libasr/codegen/c_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1126,8 +1126,11 @@ class CCPPDSUtils {
return typecodeToDSfuncs[dict_type_code]["dict_insert"];
}

std::string get_dict_get_func(ASR::Dict_t* d_type) {
std::string get_dict_get_func(ASR::Dict_t* d_type, bool with_fallback=false) {
std::string dict_type_code = ASRUtils::get_type_code((ASR::ttype_t*)d_type, true);
if (with_fallback) {
return typecodeToDSfuncs[dict_type_code]["dict_get_fb"];
}
return typecodeToDSfuncs[dict_type_code]["dict_get"];
}

Expand Down Expand Up @@ -1177,6 +1180,7 @@ class CCPPDSUtils {
dict_resize(dict_type, dict_struct_type, dict_type_code);
dict_insert(dict_type, dict_struct_type, dict_type_code);
dict_get_item(dict_type, dict_struct_type, dict_type_code);
dict_get_item_with_fallback(dict_type, dict_struct_type, dict_type_code);
dict_len(dict_type, dict_struct_type, dict_type_code);
dict_pop(dict_type, dict_struct_type, dict_type_code);
dict_deepcopy(dict_type, dict_struct_type, dict_type_code);
Expand Down Expand Up @@ -1294,11 +1298,31 @@ class CCPPDSUtils {
generated_code += indent + tab + "int j=k\%x->capacity, c = 0;\n";
generated_code += indent + tab + "while(c<x->capacity && x->present[j] && !(x->key[j] == k)) j=(j+1)\%x->capacity, c++;\n";
generated_code += indent + tab + "if (x->present[j] && x->key[j] == k) return x->value[j];\n";
generated_code += indent + tab + "printf(\"Key not found\");\n";
generated_code += indent + tab + "printf(\"Key not found\\n\");\n";
generated_code += indent + tab + "exit(1);\n";
generated_code += indent + "}\n\n";
}

void dict_get_item_with_fallback(ASR::Dict_t *dict_type, std::string dict_struct_type,
std::string dict_type_code) {
std::string indent(indentation_level * indentation_spaces, ' ');
std::string tab(indentation_spaces, ' ');
std::string dict_get_func = global_scope->get_unique_name("dict_get_item_fb_" + dict_type_code);
typecodeToDSfuncs[dict_type_code]["dict_get_fb"] = dict_get_func;
std::string key = CUtils::get_c_type_from_ttype_t(dict_type->m_key_type);
std::string val = CUtils::get_c_type_from_ttype_t(dict_type->m_value_type);
std::string signature = val + " " + dict_get_func + "(" + dict_struct_type + "* x, " +\
key + " k, " + val + " dv)";
func_decls += indent + "inline " + signature + ";\n";
signature = indent + signature;
generated_code += indent + signature + " {\n";
generated_code += indent + tab + "int j=k\%x->capacity, c = 0;\n";
generated_code += indent + tab + "while(c<x->capacity && x->present[j] && !(x->key[j] == k)) j=(j+1)\%x->capacity, c++;\n";
generated_code += indent + tab + "if (x->present[j] && x->key[j] == k) return x->value[j];\n";
generated_code += indent + tab + "return dv;\n";
generated_code += indent + "}\n\n";
}

void dict_len(ASR::Dict_t *dict_type, std::string dict_struct_type,
std::string dict_type_code) {
std::string indent(indentation_level * indentation_spaces, ' ');
Expand Down
Loading