Skip to content

Commit 71533e3

Browse files
committed
wip
1 parent 395b7b3 commit 71533e3

File tree

5 files changed

+125
-63
lines changed

5 files changed

+125
-63
lines changed

integration_tests/test_dict_06.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from ltypes import i32, f64
2+
3+
def test_dict():
4+
graph: dict[i32, dict[i32, f64]] = {0: {2: 1.0/2.0}, 1: {3: 1.0/3.0}}
5+
i: i32; j: i32; nodes: i32; eps: f64 = 1e-12
6+
nodes = 100
7+
8+
for i in range(nodes):
9+
graph[i] = {}
10+
for j in range(nodes):
11+
graph[i][j] = float(abs(j - i))
12+
13+
for i in range(nodes):
14+
for j in range(nodes):
15+
print(graph[i][j], float(abs(j - i)))
16+
# assert abs( graph[i][j] - float(abs(j - i)) ) <= eps
17+
18+
test_dict()

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,7 +1167,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
11671167
this->visit_expr(*x.m_args[i]);
11681168
llvm::Value* item = tmp;
11691169
llvm::Value* pos = llvm::ConstantInt::get(context, llvm::APInt(32, i));
1170-
list_api->write_item(const_list, pos, item, list_type->m_type, *module);
1170+
list_api->write_item(const_list, pos, item, list_type->m_type, module.get());
11711171
}
11721172
ptr_loads = ptr_loads_copy;
11731173
tmp = const_list;
@@ -1267,7 +1267,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
12671267
llvm::Value *item = tmp;
12681268
ptr_loads = ptr_loads_copy;
12691269

1270-
list_api->append(plist, item, asr_list->m_type, *module);
1270+
list_api->append(plist, item, asr_list->m_type, module.get());
12711271
}
12721272

12731273
void visit_ListItem(const ASR::ListItem_t& x) {
@@ -1303,7 +1303,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
13031303

13041304
set_dict_api(dict_type);
13051305
tmp = llvm_utils->dict_api->read_item(pdict, key, *module, dict_type,
1306-
LLVM::is_llvm_struct(dict_type->m_value_type));
1306+
LLVM::is_llvm_struct(dict_type->m_value_type) ||
1307+
ptr_loads == 0);
13071308
}
13081309

13091310
void visit_DictPop(const ASR::DictPop_t& x) {
@@ -1370,7 +1371,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
13701371
llvm::Value *item = tmp;
13711372
ptr_loads = ptr_loads_copy;
13721373

1373-
list_api->insert_item(plist, pos, item, asr_list->m_type, *module);
1374+
list_api->insert_item(plist, pos, item, asr_list->m_type, module.get());
13741375
}
13751376

13761377
void visit_DictInsert(const ASR::DictInsert_t& x) {
@@ -3191,7 +3192,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
31913192
ASRUtils::expr_type(x.m_value));
31923193
std::string value_type_code = ASRUtils::get_type_code(value_asr_list->m_type);
31933194
list_api->list_deepcopy(value_list, target_list,
3194-
value_asr_list, *module);
3195+
value_asr_list, module.get());
31953196
return ;
31963197
} else if( is_target_tuple && is_value_tuple ) {
31973198
uint64_t ptr_loads_copy = ptr_loads;
@@ -3220,7 +3221,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
32203221
llvm::Value* llvm_tuple_i = builder->CreateAlloca(llvm_tuple_i_type, nullptr);
32213222
ptr_loads = !LLVM::is_llvm_struct(asr_tuple_i_type);
32223223
visit_expr(*asr_value_tuple->m_elements[i]);
3223-
llvm_utils->deepcopy(tmp, llvm_tuple_i, asr_tuple_i_type, *module);
3224+
llvm_utils->deepcopy(tmp, llvm_tuple_i, asr_tuple_i_type, module.get());
32243225
src_deepcopies.push_back(al, llvm_tuple_i);
32253226
}
32263227
ASR::TupleConstant_t* asr_target_tuple = ASR::down_cast<ASR::TupleConstant_t>(x.m_target);
@@ -3244,7 +3245,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
32443245
std::string type_code = ASRUtils::get_type_code(value_tuple_type->m_type,
32453246
value_tuple_type->n_type);
32463247
tuple_api->tuple_deepcopy(value_tuple, target_tuple,
3247-
value_tuple_type, *module);
3248+
value_tuple_type, module.get());
32483249
}
32493250
return ;
32503251
} else if( is_target_dict && is_value_dict ) {
@@ -3277,7 +3278,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
32773278
if( x.m_target->type == ASR::exprType::ArrayItem ||
32783279
x.m_target->type == ASR::exprType::ArraySection ||
32793280
x.m_target->type == ASR::exprType::DerivedRef ||
3280-
x.m_target->type == ASR::exprType::ListItem ) {
3281+
x.m_target->type == ASR::exprType::ListItem ||
3282+
x.m_target->type == ASR::exprType::DictItem ) {
32813283
is_assignment_target = true;
32823284
this->visit_expr(*x.m_target);
32833285
is_assignment_target = false;
@@ -3316,6 +3318,12 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
33163318
this->visit_expr_wrapper(asr_target0->m_pos, true);
33173319
llvm::Value* pos = tmp;
33183320
target = list_api->read_item(list, pos, true);
3321+
} else if( ASR::is_a<ASR::DictItem_t>(*x.m_target) ) {
3322+
uint64_t ptr_loads_copy = ptr_loads;
3323+
ptr_loads = 0;
3324+
visit_expr(*x.m_target);
3325+
ptr_loads = ptr_loads_copy;
3326+
target = tmp;
33193327
}
33203328
} else {
33213329
ASR::Variable_t *asr_target = EXPR2VAR(x.m_target);
@@ -5295,7 +5303,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
52955303
}
52965304
if( ASR::is_a<ASR::Tuple_t>(*arg_type) ||
52975305
ASR::is_a<ASR::List_t>(*arg_type) ) {
5298-
llvm_utils->deepcopy(value, target, arg_type, *module);
5306+
llvm_utils->deepcopy(value, target, arg_type, module.get());
52995307
} else {
53005308
builder->CreateStore(value, target);
53015309
}

src/libasr/codegen/llvm_utils.cpp

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ namespace LFortran {
298298
}
299299

300300
void LLVMUtils::deepcopy(llvm::Value* src, llvm::Value* dest,
301-
ASR::ttype_t* asr_type, llvm::Module& module) {
301+
ASR::ttype_t* asr_type, llvm::Module* module) {
302302
switch( asr_type->type ) {
303303
case ASR::ttypeType::Integer:
304304
case ASR::ttypeType::Real:
@@ -318,6 +318,11 @@ namespace LFortran {
318318
list_api->list_deepcopy(src, dest, list_type, module);
319319
break ;
320320
}
321+
case ASR::ttypeType::Dict: {
322+
ASR::Dict_t* dict_type = ASR::down_cast<ASR::Dict_t>(asr_type);
323+
dict_api->dict_deepcopy(src, dest, dict_type, module);
324+
break;
325+
}
321326
default: {
322327
throw LCompilersException("LLVMUtils::deepcopy isn't implemented for " +
323328
ASRUtils::type_to_str_python(asr_type));
@@ -611,12 +616,12 @@ namespace LFortran {
611616
}
612617

613618
void LLVMList::list_deepcopy(llvm::Value* src, llvm::Value* dest,
614-
ASR::List_t* list_type, llvm::Module& module) {
619+
ASR::List_t* list_type, llvm::Module* module) {
615620
list_deepcopy(src, dest, list_type->m_type, module);
616621
}
617622

618623
void LLVMList::list_deepcopy(llvm::Value* src, llvm::Value* dest,
619-
ASR::ttype_t* element_type, llvm::Module& module) {
624+
ASR::ttype_t* element_type, llvm::Module* module) {
620625
LFORTRAN_ASSERT(src->getType() == dest->getType());
621626
std::string src_type_code = ASRUtils::get_type_code(element_type);
622627
llvm::Value* src_end_point = LLVM::CreateLoad(*builder, get_pointer_to_current_end_point(src));
@@ -629,7 +634,7 @@ namespace LFortran {
629634
int32_t type_size = std::get<1>(typecode2listtype[src_type_code]);
630635
llvm::Value* arg_size = builder->CreateMul(llvm::ConstantInt::get(context,
631636
llvm::APInt(32, type_size)), src_capacity);
632-
llvm::Value* copy_data = LLVM::lfortran_malloc(context, module, *builder,
637+
llvm::Value* copy_data = LLVM::lfortran_malloc(context, *module, *builder,
633638
arg_size);
634639
llvm::Type* el_type = std::get<2>(typecode2listtype[src_type_code]);
635640
copy_data = builder->CreateBitCast(copy_data, el_type->getPointerTo());
@@ -700,12 +705,12 @@ namespace LFortran {
700705
llvm::Value* src_key_list = get_key_list(src);
701706
llvm::Value* dest_key_list = get_key_list(dest);
702707
llvm_utils->list_api->list_deepcopy(src_key_list, dest_key_list,
703-
dict_type->m_key_type, *module);
708+
dict_type->m_key_type, module);
704709

705710
llvm::Value* src_value_list = get_value_list(src);
706711
llvm::Value* dest_value_list = get_value_list(dest);
707712
llvm_utils->list_api->list_deepcopy(src_value_list, dest_value_list,
708-
dict_type->m_value_type, *module);
713+
dict_type->m_value_type, module);
709714

710715
llvm::Value* src_key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(src));
711716
llvm::Value* dest_key_mask_ptr = get_pointer_to_keymask(dest);
@@ -768,8 +773,8 @@ namespace LFortran {
768773
}
769774
llvm::Value* dest_key_ptr = llvm_utils->create_gep(curr_dest, 0);
770775
llvm::Value* dest_value_ptr = llvm_utils->create_gep(curr_dest, 1);
771-
llvm_utils->deepcopy(src_key, dest_key_ptr, dict_type->m_key_type, *module);
772-
llvm_utils->deepcopy(src_value, dest_value_ptr, dict_type->m_value_type, *module);
776+
llvm_utils->deepcopy(src_key, dest_key_ptr, dict_type->m_key_type, module);
777+
llvm_utils->deepcopy(src_value, dest_value_ptr, dict_type->m_value_type, module);
773778

774779
llvm::Value* src_next_ptr = LLVM::CreateLoad(*builder, llvm_utils->create_gep(curr_src, 2));
775780
llvm::Value* curr_dest_next_ptr = llvm_utils->create_gep(curr_dest, 2);
@@ -957,7 +962,7 @@ namespace LFortran {
957962

958963
void LLVMList::write_item(llvm::Value* list, llvm::Value* pos,
959964
llvm::Value* item, ASR::ttype_t* asr_type,
960-
llvm::Module& module, bool check_index_bound) {
965+
llvm::Module* module, bool check_index_bound) {
961966
if( check_index_bound ) {
962967
check_index_within_bounds(list, pos);
963968
}
@@ -1323,9 +1328,9 @@ namespace LFortran {
13231328
this->resolve_collision(capacity, key_hash, key, key_list, key_mask, *module, key_asr_type);
13241329
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
13251330
llvm_utils->list_api->write_item(key_list, pos, key,
1326-
key_asr_type, *module, false);
1331+
key_asr_type, module, false);
13271332
llvm_utils->list_api->write_item(value_list, pos, value,
1328-
value_asr_type, *module, false);
1333+
value_asr_type, module, false);
13291334
llvm::Value* key_mask_value = LLVM::CreateLoad(*builder,
13301335
llvm_utils->create_ptr_gep(key_mask, pos));
13311336
llvm::Value* is_slot_empty = builder->CreateICmpEQ(key_mask_value,
@@ -1352,9 +1357,9 @@ namespace LFortran {
13521357
this->resolve_collision(capacity, key_hash, key, key_list, key_mask, *module, key_asr_type);
13531358
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
13541359
llvm_utils->list_api->write_item(key_list, pos, key,
1355-
key_asr_type, *module, false);
1360+
key_asr_type, module, false);
13561361
llvm_utils->list_api->write_item(value_list, pos, value,
1357-
value_asr_type, *module, false);
1362+
value_asr_type, module, false);
13581363

13591364
llvm::Value* key_mask_value = LLVM::CreateLoad(*builder,
13601365
llvm_utils->create_ptr_gep(key_mask, pos));
@@ -1407,8 +1412,8 @@ namespace LFortran {
14071412
llvm::Value* malloc_size = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), kv_struct_size);
14081413
llvm::Value* new_kv_struct_i8 = LLVM::lfortran_malloc(context, *module, *builder, malloc_size);
14091414
llvm::Value* new_kv_struct = builder->CreateBitCast(new_kv_struct_i8, kv_struct_type->getPointerTo());
1410-
llvm_utils->deepcopy(key, llvm_utils->create_gep(new_kv_struct, 0), key_asr_type, *module);
1411-
llvm_utils->deepcopy(value, llvm_utils->create_gep(new_kv_struct, 1), value_asr_type, *module);
1415+
llvm_utils->deepcopy(key, llvm_utils->create_gep(new_kv_struct, 0), key_asr_type, module);
1416+
llvm_utils->deepcopy(value, llvm_utils->create_gep(new_kv_struct, 1), value_asr_type, module);
14121417
LLVM::CreateStore(*builder,
14131418
llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)),
14141419
llvm_utils->create_gep(new_kv_struct, 2));
@@ -1420,8 +1425,8 @@ namespace LFortran {
14201425
llvm_utils->start_new_block(elseBB);
14211426
{
14221427
llvm::Value* kv_struct = builder->CreateBitCast(kv_struct_i8, kv_struct_type->getPointerTo());
1423-
llvm_utils->deepcopy(key, llvm_utils->create_gep(kv_struct, 0), key_asr_type, *module);
1424-
llvm_utils->deepcopy(value, llvm_utils->create_gep(kv_struct, 1), value_asr_type, *module);
1428+
llvm_utils->deepcopy(key, llvm_utils->create_gep(kv_struct, 0), key_asr_type, module);
1429+
llvm_utils->deepcopy(value, llvm_utils->create_gep(kv_struct, 1), value_asr_type, module);
14251430
}
14261431
llvm_utils->start_new_block(mergeBB);
14271432
llvm::Value* occupancy_ptr = get_pointer_to_occupancy(dict);
@@ -1766,10 +1771,10 @@ namespace LFortran {
17661771
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
17671772
llvm::Value* key_dest = llvm_utils->list_api->read_item(new_key_list, pos,
17681773
true, false);
1769-
llvm_utils->deepcopy(key, key_dest, key_asr_type, *module);
1774+
llvm_utils->deepcopy(key, key_dest, key_asr_type, module);
17701775
llvm::Value* value_dest = llvm_utils->list_api->read_item(new_value_list, pos,
17711776
true, false);
1772-
llvm_utils->deepcopy(value, value_dest, value_asr_type, *module);
1777+
llvm_utils->deepcopy(value, value_dest, value_asr_type, module);
17731778

17741779
llvm::Value* linear_prob_happened = builder->CreateICmpNE(key_hash, pos);
17751780
llvm::Value* set_max_2 = builder->CreateSelect(linear_prob_happened,
@@ -2225,21 +2230,21 @@ namespace LFortran {
22252230
}
22262231

22272232
void LLVMList::append(llvm::Value* list, llvm::Value* item,
2228-
ASR::ttype_t* asr_type, llvm::Module& module) {
2233+
ASR::ttype_t* asr_type, llvm::Module* module) {
22292234
llvm::Value* current_end_point = LLVM::CreateLoad(*builder, get_pointer_to_current_end_point(list));
22302235
llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_current_capacity(list));
22312236
std::string type_code = ASRUtils::get_type_code(asr_type);
22322237
int type_size = std::get<1>(typecode2listtype[type_code]);
22332238
llvm::Type* el_type = std::get<2>(typecode2listtype[type_code]);
22342239
resize_if_needed(list, current_end_point, current_capacity,
2235-
type_size, el_type, module);
2240+
type_size, el_type, *module);
22362241
write_item(list, current_end_point, item, asr_type, module);
22372242
shift_end_point_by_one(list);
22382243
}
22392244

22402245
void LLVMList::insert_item(llvm::Value* list, llvm::Value* pos,
22412246
llvm::Value* item, ASR::ttype_t* asr_type,
2242-
llvm::Module& module) {
2247+
llvm::Module* module) {
22432248
std::string type_code = ASRUtils::get_type_code(asr_type);
22442249
llvm::Value* current_end_point = LLVM::CreateLoad(*builder,
22452250
get_pointer_to_current_end_point(list));
@@ -2248,7 +2253,7 @@ namespace LFortran {
22482253
int type_size = std::get<1>(typecode2listtype[type_code]);
22492254
llvm::Type* el_type = std::get<2>(typecode2listtype[type_code]);
22502255
resize_if_needed(list, current_end_point, current_capacity,
2251-
type_size, el_type, module);
2256+
type_size, el_type, *module);
22522257

22532258
/* While loop equivalent in C++:
22542259
* end_point // nth index of list
@@ -2514,7 +2519,7 @@ namespace LFortran {
25142519
}
25152520

25162521
void LLVMTuple::tuple_deepcopy(llvm::Value* src, llvm::Value* dest,
2517-
ASR::Tuple_t* tuple_type, llvm::Module& module) {
2522+
ASR::Tuple_t* tuple_type, llvm::Module* module) {
25182523
LFORTRAN_ASSERT(src->getType() == dest->getType());
25192524
for( size_t i = 0; i < tuple_type->n_type; i++ ) {
25202525
llvm::Value* src_item = read_item(src, i, LLVM::is_llvm_struct(

src/libasr/codegen/llvm_utils.h

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ namespace LFortran {
5757
return ASR::is_a<ASR::Tuple_t>(*asr_type) ||
5858
ASR::is_a<ASR::List_t>(*asr_type) ||
5959
ASR::is_a<ASR::Derived_t>(*asr_type) ||
60-
ASR::is_a<ASR::Class_t>(*asr_type);
60+
ASR::is_a<ASR::Class_t>(*asr_type) ||
61+
ASR::is_a<ASR::Dict_t>(*asr_type);
6162
}
6263
}
6364

@@ -107,7 +108,7 @@ namespace LFortran {
107108
void reset_iterators();
108109

109110
void deepcopy(llvm::Value* src, llvm::Value* dest,
110-
ASR::ttype_t* asr_type, llvm::Module& module);
111+
ASR::ttype_t* asr_type, llvm::Module* module);
111112

112113
}; // LLVMUtils
113114

@@ -150,11 +151,11 @@ namespace LFortran {
150151

151152
void list_deepcopy(llvm::Value* src, llvm::Value* dest,
152153
ASR::List_t* list_type,
153-
llvm::Module& module);
154+
llvm::Module* module);
154155

155156
void list_deepcopy(llvm::Value* src, llvm::Value* dest,
156157
ASR::ttype_t* element_type,
157-
llvm::Module& module);
158+
llvm::Module* module);
158159

159160
llvm::Value* read_item(llvm::Value* list, llvm::Value* pos,
160161
bool get_pointer=false, bool check_index_bound=true);
@@ -165,17 +166,17 @@ namespace LFortran {
165166

166167
void write_item(llvm::Value* list, llvm::Value* pos,
167168
llvm::Value* item, ASR::ttype_t* asr_type,
168-
llvm::Module& module, bool check_index_bound=true);
169+
llvm::Module* module, bool check_index_bound=true);
169170

170171
void write_item(llvm::Value* list, llvm::Value* pos,
171172
llvm::Value* item, bool check_index_bound=true);
172173

173174
void append(llvm::Value* list, llvm::Value* item,
174-
ASR::ttype_t* asr_type, llvm::Module& module);
175+
ASR::ttype_t* asr_type, llvm::Module* module);
175176

176177
void insert_item(llvm::Value* list, llvm::Value* pos,
177178
llvm::Value* item, ASR::ttype_t* asr_type,
178-
llvm::Module& module);
179+
llvm::Module* module);
179180

180181
void remove(llvm::Value* list, llvm::Value* item,
181182
ASR::ttype_t* item_type, llvm::Module& module);
@@ -216,7 +217,7 @@ namespace LFortran {
216217
bool get_pointer=false);
217218

218219
void tuple_deepcopy(llvm::Value* src, llvm::Value* dest,
219-
ASR::Tuple_t* type_code, llvm::Module& module);
220+
ASR::Tuple_t* type_code, llvm::Module* module);
220221

221222
llvm::Value* check_tuple_equality(llvm::Value* t1, llvm::Value* t2,
222223
ASR::Tuple_t* tuple_type, llvm::LLVMContext& context,

0 commit comments

Comments
 (0)