Skip to content

Commit 61443f7

Browse files
committed
Use dict_api to implement DictConstant, Assignment of Dicts and DictInsert
1 parent 3bb54d3 commit 61443f7

File tree

1 file changed

+131
-0
lines changed

1 file changed

+131
-0
lines changed

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
221221
std::unique_ptr<LLVMUtils> llvm_utils;
222222
std::unique_ptr<LLVMList> list_api;
223223
std::unique_ptr<LLVMTuple> tuple_api;
224+
std::unique_ptr<LLVMDict> dict_api;
224225
std::unique_ptr<LLVMArrUtils::Descriptor> arr_descr;
225226

226227
uint64_t ptr_loads;
@@ -237,6 +238,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
237238
llvm_utils(std::make_unique<LLVMUtils>(context, builder.get())),
238239
list_api(std::make_unique<LLVMList>(context, llvm_utils.get(), builder.get())),
239240
tuple_api(std::make_unique<LLVMTuple>(context, llvm_utils.get(), builder.get())),
241+
dict_api(std::make_unique<LLVMDict>(context, llvm_utils.get(), builder.get())),
240242
arr_descr(LLVMArrUtils::Descriptor::get_descriptor(context,
241243
builder.get(),
242244
llvm_utils.get(),
@@ -246,6 +248,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
246248
{
247249
llvm_utils->tuple_api = tuple_api.get();
248250
llvm_utils->list_api = list_api.get();
251+
llvm_utils->dict_api = dict_api.get();
249252
}
250253

251254
llvm::Value* CreateLoad(llvm::Value *x) {
@@ -1155,6 +1158,30 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
11551158
tmp = const_list;
11561159
}
11571160

1161+
void visit_DictConstant(const ASR::DictConstant_t& x) {
1162+
llvm::Type* const_dict_type = get_dict_type(x.m_type);
1163+
llvm::Value* const_dict = builder->CreateAlloca(const_dict_type, nullptr, "const_dict");
1164+
ASR::Dict_t* x_dict = ASR::down_cast<ASR::Dict_t>(x.m_type);
1165+
std::string key_type_code = ASRUtils::get_type_code(x_dict->m_key_type);
1166+
std::string value_type_code = ASRUtils::get_type_code(x_dict->m_value_type);
1167+
dict_api->dict_init(key_type_code, value_type_code, const_dict, module.get(), x.n_keys);
1168+
uint64_t ptr_loads_key = LLVM::is_llvm_struct(x_dict->m_key_type) ? 0 : 2;
1169+
uint64_t ptr_loads_value = LLVM::is_llvm_struct(x_dict->m_value_type) ? 0 : 2;
1170+
uint64_t ptr_loads_copy = ptr_loads;
1171+
for( size_t i = 0; i < x.n_keys; i++ ) {
1172+
ptr_loads = ptr_loads_key;
1173+
visit_expr(*x.m_keys[i]);
1174+
llvm::Value* key = tmp;
1175+
ptr_loads = ptr_loads_value;
1176+
visit_expr(*x.m_values[i]);
1177+
llvm::Value* value = tmp;
1178+
dict_api->write_item(const_dict, key, value, module.get(),
1179+
x_dict->m_key_type, x_dict->m_value_type);
1180+
}
1181+
ptr_loads = ptr_loads_copy;
1182+
tmp = const_dict;
1183+
}
1184+
11581185
void visit_TupleConstant(const ASR::TupleConstant_t& x) {
11591186
ASR::Tuple_t* tuple_type = ASR::down_cast<ASR::Tuple_t>(x.m_type);
11601187
std::string type_code = ASRUtils::get_type_code(tuple_type->m_type,
@@ -1235,6 +1262,23 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
12351262
tmp = list_api->read_item(plist, pos, LLVM::is_llvm_struct(el_type));
12361263
}
12371264

1265+
void visit_DictItem(const ASR::DictItem_t& x) {
1266+
ASR::Dict_t* dict_type = ASR::down_cast<ASR::Dict_t>(
1267+
ASRUtils::expr_type(x.m_a));
1268+
uint64_t ptr_loads_copy = ptr_loads;
1269+
ptr_loads = 0;
1270+
this->visit_expr(*x.m_a);
1271+
llvm::Value* pdict = tmp;
1272+
1273+
ptr_loads = !LLVM::is_llvm_struct(dict_type->m_key_type);
1274+
this->visit_expr_wrapper(x.m_key, true);
1275+
ptr_loads = ptr_loads_copy;
1276+
llvm::Value *key = tmp;
1277+
1278+
tmp = dict_api->read_item(pdict, key, *module, dict_type->m_key_type,
1279+
LLVM::is_llvm_struct(dict_type->m_value_type));
1280+
}
1281+
12381282
void visit_ListLen(const ASR::ListLen_t& x) {
12391283
if (x.m_value) {
12401284
this->visit_expr(*x.m_value);
@@ -1248,6 +1292,20 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
12481292
}
12491293
}
12501294

1295+
void visit_DictLen(const ASR::DictLen_t& x) {
1296+
if (x.m_value) {
1297+
this->visit_expr(*x.m_value);
1298+
return ;
1299+
}
1300+
1301+
uint64_t ptr_loads_copy = ptr_loads;
1302+
ptr_loads = 0;
1303+
this->visit_expr(*x.m_arg);
1304+
ptr_loads = ptr_loads_copy;
1305+
llvm::Value* pdict = tmp;
1306+
tmp = dict_api->len(pdict);
1307+
}
1308+
12511309
void visit_ListInsert(const ASR::ListInsert_t& x) {
12521310
ASR::List_t* asr_list = ASR::down_cast<ASR::List_t>(
12531311
ASRUtils::expr_type(x.m_a));
@@ -1268,6 +1326,26 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
12681326
list_api->insert_item(plist, pos, item, asr_list->m_type, *module);
12691327
}
12701328

1329+
void visit_DictInsert(const ASR::DictInsert_t& x) {
1330+
ASR::Dict_t* dict_type = ASR::down_cast<ASR::Dict_t>(
1331+
ASRUtils::expr_type(x.m_a));
1332+
uint64_t ptr_loads_copy = ptr_loads;
1333+
ptr_loads = 0;
1334+
this->visit_expr(*x.m_a);
1335+
llvm::Value* pdict = tmp;
1336+
1337+
ptr_loads = !LLVM::is_llvm_struct(dict_type->m_key_type);
1338+
this->visit_expr_wrapper(x.m_key, true);
1339+
llvm::Value *key = tmp;
1340+
this->visit_expr_wrapper(x.m_value, true);
1341+
llvm::Value *value = tmp;
1342+
ptr_loads = ptr_loads_copy;
1343+
1344+
dict_api->write_item(pdict, key, value, module.get(),
1345+
dict_type->m_key_type,
1346+
dict_type->m_value_type);
1347+
}
1348+
12711349
void visit_ListRemove(const ASR::ListRemove_t& x) {
12721350
ASR::ttype_t* asr_el_type = ASRUtils::get_contained_type(ASRUtils::expr_type(x.m_a));
12731351
uint64_t ptr_loads_copy = ptr_loads;
@@ -1717,6 +1795,41 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
17171795
return false;
17181796
}
17191797

1798+
int32_t get_type_size(ASR::ttype_t* asr_type, llvm::Type* llvm_type,
1799+
int32_t a_kind) {
1800+
if( LLVM::is_llvm_struct(asr_type) ||
1801+
ASR::is_a<ASR::Character_t>(*asr_type) ||
1802+
ASR::is_a<ASR::Complex_t>(*asr_type) ) {
1803+
llvm::DataLayout data_layout(module.get());
1804+
return data_layout.getTypeAllocSize(llvm_type);
1805+
}
1806+
return a_kind;
1807+
}
1808+
1809+
llvm::Type* get_dict_type(ASR::ttype_t* asr_type) {
1810+
ASR::Dict_t* asr_dict = ASR::down_cast<ASR::Dict_t>(asr_type);
1811+
bool is_local_array_type = false, is_local_malloc_array_type = false;
1812+
bool is_local_list = false;
1813+
ASR::dimension_t* local_m_dims = nullptr;
1814+
int local_n_dims = 0;
1815+
int local_a_kind = -1;
1816+
ASR::storage_typeType local_m_storage = ASR::storage_typeType::Default;
1817+
llvm::Type* key_llvm_type = get_type_from_ttype_t(asr_dict->m_key_type, local_m_storage,
1818+
is_local_array_type, is_local_malloc_array_type,
1819+
is_local_list, local_m_dims, local_n_dims,
1820+
local_a_kind);
1821+
int32_t key_type_size = get_type_size(asr_dict->m_key_type, key_llvm_type, local_a_kind);
1822+
llvm::Type* value_llvm_type = get_type_from_ttype_t(asr_dict->m_value_type, local_m_storage,
1823+
is_local_array_type, is_local_malloc_array_type,
1824+
is_local_list, local_m_dims, local_n_dims,
1825+
local_a_kind);
1826+
int32_t value_type_size = get_type_size(asr_dict->m_value_type, value_llvm_type, local_a_kind);
1827+
std::string key_type_code = ASRUtils::get_type_code(asr_dict->m_key_type);
1828+
std::string value_type_code = ASRUtils::get_type_code(asr_dict->m_value_type);
1829+
return dict_api->get_dict_type(key_type_code, value_type_code, key_type_size,
1830+
value_type_size, key_llvm_type, value_llvm_type);
1831+
}
1832+
17201833
llvm::Type* get_type_from_ttype_t(ASR::ttype_t* asr_type,
17211834
ASR::storage_typeType m_storage,
17221835
bool& is_array_type, bool& is_malloc_array_type,
@@ -1865,6 +1978,10 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
18651978
llvm_type = list_api->get_list_type(el_llvm_type, el_type_code, type_size);
18661979
break;
18671980
}
1981+
case (ASR::ttypeType::Dict): {
1982+
llvm_type = get_dict_type(asr_type);
1983+
break;
1984+
}
18681985
case (ASR::ttypeType::Tuple) : {
18691986
ASR::Tuple_t* asr_tuple = ASR::down_cast<ASR::Tuple_t>(asr_type);
18701987
std::string type_code = ASRUtils::get_type_code(asr_tuple->m_type,
@@ -2959,6 +3076,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
29593076
bool is_value_list = ASR::is_a<ASR::List_t>(*asr_value_type);
29603077
bool is_target_tuple = ASR::is_a<ASR::Tuple_t>(*asr_target_type);
29613078
bool is_value_tuple = ASR::is_a<ASR::Tuple_t>(*asr_value_type);
3079+
bool is_target_dict = ASR::is_a<ASR::Dict_t>(*asr_target_type);
3080+
bool is_value_dict = ASR::is_a<ASR::Dict_t>(*asr_value_type);
29623081
if( is_target_list && is_value_list ) {
29633082
uint64_t ptr_loads_copy = ptr_loads;
29643083
ptr_loads = 0;
@@ -3014,6 +3133,18 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
30143133
}
30153134
}
30163135
return ;
3136+
} else if( is_target_dict && is_value_dict ) {
3137+
uint64_t ptr_loads_copy = ptr_loads;
3138+
ptr_loads = 0;
3139+
this->visit_expr(*x.m_value);
3140+
llvm::Value* value_dict = tmp;
3141+
this->visit_expr(*x.m_target);
3142+
llvm::Value* target_dict = tmp;
3143+
ptr_loads = ptr_loads_copy;
3144+
ASR::Dict_t* value_dict_type = ASR::down_cast<ASR::Dict_t>(asr_value_type);
3145+
dict_api->dict_deepcopy(value_dict, target_dict,
3146+
value_dict_type, module.get());
3147+
return ;
30173148
}
30183149
if( ASR::is_a<ASR::Pointer_t>(*ASRUtils::expr_type(x.m_target)) &&
30193150
ASR::is_a<ASR::GetPointer_t>(*x.m_value) ) {

0 commit comments

Comments
 (0)