diff --git a/integration_tests/test_dict_06.py b/integration_tests/test_dict_06.py new file mode 100644 index 0000000000..c1a4b4c817 --- /dev/null +++ b/integration_tests/test_dict_06.py @@ -0,0 +1,25 @@ +from ltypes import i32, f64 + +def test_dict(): + graph: dict[i32, dict[i32, f64]] = {0: {2: 1.0/2.0}, 1: {3: 1.0/4.0}} + i: i32; j: i32; nodes: i32; eps: f64 = 1e-12 + nodes = 100 + + assert abs(graph[0][2] - 0.5) <= eps + assert abs(graph[1][3] - 0.25) <= eps + + for i in range(1, nodes): + graph[i] = {} + for j in range(1, nodes): + graph[i][j] = 1.0/float(j + i) + + for i in range(1, nodes): + for j in range(1, nodes): + assert abs( graph[i][j] - 1.0/float(j + i) ) <= eps + + for i in range(1, nodes): + for j in range(1, nodes): + assert abs( graph[i].pop(j) - 1.0/float(j + i) ) <= eps + graph.pop(i) + +test_dict() diff --git a/integration_tests/test_dict_07.py b/integration_tests/test_dict_07.py new file mode 100644 index 0000000000..d225dbd461 --- /dev/null +++ b/integration_tests/test_dict_07.py @@ -0,0 +1,32 @@ +from ltypes import i32, f64 +from math import sin + +def test_dict(): + friends: dict[tuple[i32, str], dict[tuple[i32, str], tuple[f64, str]]] = {} + eps: f64 = 1e-12 + nodes: i32 = 100 + i: i32; j: i32; + t: tuple[f64, str] + + friends[(0, "0")] = {(1, "1"): (sin(-1.0), "01")} + assert friends[(0, "0")][(1, "1")][0] - sin(-1.0) <= eps + + for i in range(nodes): + friends[(i, str(i))] = {} + for j in range(i, nodes): + friends[(i, str(i))][(j, str(j))] = (sin(float(i - j)), str(i) + str(j)) + + for i in range(nodes): + for j in range(i, nodes): + t = friends[(i, str(i))][(j, str(j))] + assert abs( t[0] - sin(float(i - j)) ) <= eps + assert t[1] == str(i) + str(j) + + for i in range(nodes): + for j in range(i, nodes): + t = friends[(i, str(i))].pop((j, str(j))) + assert abs( t[0] - sin(float(i - j)) ) <= eps + assert t[1] == str(i) + str(j) + friends.pop((i, str(i))) + +test_dict() diff --git a/integration_tests/test_dict_08.py b/integration_tests/test_dict_08.py new file mode 100644 index 0000000000..f0708d3294 --- /dev/null +++ b/integration_tests/test_dict_08.py @@ -0,0 +1,32 @@ +from ltypes import i32, f64 +from math import sin + +def test_dict(): + friends: dict[str, dict[i32, dict[str, f64]]] = {} + eps: f64 = 1e-12 + nodes: i32 = 9 + i: i32; j: i32; k: i32; + + friends["0"] = {1: {"2": sin(3.0)}} + assert friends["0"][1]["2"] - sin(3.0) <= eps + + for i in range(nodes): + friends[str(i)] = {} + for j in range(nodes): + friends[str(i)][j] = {} + for k in range(nodes): + friends[str(i)][j][str(k)] = sin(float(i + j + k)) + + for i in range(nodes): + for j in range(nodes): + for k in range(nodes): + abs( friends[str(i)][j][str(k)] - sin(float(i + j + k)) ) <= eps + + for i in range(nodes): + for j in range(nodes): + for k in range(nodes): + abs( friends[str(i)][j].pop(str(k)) - sin(float(i + j + k)) ) <= eps + friends[str(i)].pop(j) + friends.pop(str(i)) + +test_dict() diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index a3206a9a96..1cd3efb804 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -1167,18 +1167,22 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor this->visit_expr(*x.m_args[i]); llvm::Value* item = tmp; llvm::Value* pos = llvm::ConstantInt::get(context, llvm::APInt(32, i)); - list_api->write_item(const_list, pos, item, list_type->m_type, *module); + list_api->write_item(const_list, pos, item, list_type->m_type, module.get()); } ptr_loads = ptr_loads_copy; tmp = const_list; } void set_dict_api(ASR::Dict_t* dict_type) { - if( ASR::is_a(*dict_type->m_key_type) ) { - llvm_utils->dict_api = dict_api_sc.get(); - } else { - llvm_utils->dict_api = dict_api_lp.get(); - } + // if( llvm_utils->dict_api != nullptr ) { + // return ; + // } + // if( ASR::is_a(*dict_type->m_key_type) ) { + // llvm_utils->dict_api = dict_api_sc.get(); + // } else { + // llvm_utils->dict_api = dict_api_lp.get(); + // } + llvm_utils->dict_api = dict_api_lp.get(); } void visit_DictConstant(const ASR::DictConstant_t& x) { @@ -1267,7 +1271,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor llvm::Value *item = tmp; ptr_loads = ptr_loads_copy; - list_api->append(plist, item, asr_list->m_type, *module); + list_api->append(plist, item, asr_list->m_type, module.get()); } void visit_ListItem(const ASR::ListItem_t& x) { @@ -1303,7 +1307,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor set_dict_api(dict_type); tmp = llvm_utils->dict_api->read_item(pdict, key, *module, dict_type, - LLVM::is_llvm_struct(dict_type->m_value_type)); + LLVM::is_llvm_struct(dict_type->m_value_type) || + is_assignment_target); } void visit_DictPop(const ASR::DictPop_t& x) { @@ -1370,7 +1375,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor llvm::Value *item = tmp; ptr_loads = ptr_loads_copy; - list_api->insert_item(plist, pos, item, asr_list->m_type, *module); + list_api->insert_item(plist, pos, item, asr_list->m_type, module.get()); } void visit_DictInsert(const ASR::DictInsert_t& x) { @@ -1864,8 +1869,11 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor return a_kind; } - llvm::Type* get_dict_type(ASR::ttype_t* asr_type) { + llvm::Type* get_dict_type(ASR::ttype_t* asr_type, bool is_nested_call=false) { ASR::Dict_t* asr_dict = ASR::down_cast(asr_type); + if( !is_nested_call ) { + set_dict_api(asr_dict); + } bool is_local_array_type = false, is_local_malloc_array_type = false; bool is_local_list = false; ASR::dimension_t* local_m_dims = nullptr; @@ -1873,18 +1881,17 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor int local_a_kind = -1; ASR::storage_typeType local_m_storage = ASR::storage_typeType::Default; llvm::Type* key_llvm_type = get_type_from_ttype_t(asr_dict->m_key_type, local_m_storage, - is_local_array_type, is_local_malloc_array_type, - is_local_list, local_m_dims, local_n_dims, - local_a_kind); + is_local_array_type, is_local_malloc_array_type, + is_local_list, local_m_dims, local_n_dims, + local_a_kind, true); int32_t key_type_size = get_type_size(asr_dict->m_key_type, key_llvm_type, local_a_kind); llvm::Type* value_llvm_type = get_type_from_ttype_t(asr_dict->m_value_type, local_m_storage, is_local_array_type, is_local_malloc_array_type, is_local_list, local_m_dims, local_n_dims, - local_a_kind); + local_a_kind, true); int32_t value_type_size = get_type_size(asr_dict->m_value_type, value_llvm_type, local_a_kind); std::string key_type_code = ASRUtils::get_type_code(asr_dict->m_key_type); std::string value_type_code = ASRUtils::get_type_code(asr_dict->m_value_type); - set_dict_api(asr_dict); return llvm_utils->dict_api->get_dict_type(key_type_code, value_type_code, key_type_size, value_type_size, key_llvm_type, value_llvm_type); } @@ -1893,7 +1900,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor ASR::storage_typeType m_storage, bool& is_array_type, bool& is_malloc_array_type, bool& is_list, ASR::dimension_t*& m_dims, - int& n_dims, int& a_kind) { + int& n_dims, int& a_kind, bool is_nested_call=false) { llvm::Type* llvm_type = nullptr; switch (asr_type->type) { case (ASR::ttypeType::Integer) : { @@ -2013,7 +2020,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor ASR::ttype_t *t2 = ASR::down_cast(asr_type)->m_type; llvm_type = get_type_from_ttype_t(t2, m_storage, is_array_type, is_malloc_array_type, is_list, m_dims, - n_dims, a_kind); + n_dims, a_kind, is_nested_call); llvm_type = llvm_type->getPointerTo(); break; } @@ -2023,7 +2030,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor llvm::Type* el_llvm_type = get_type_from_ttype_t(asr_list->m_type, m_storage, is_array_type, is_malloc_array_type, is_list, m_dims, n_dims, - a_kind); + a_kind, is_nested_call); std::string el_type_code = ASRUtils::get_type_code(asr_list->m_type); int32_t type_size = -1; if( LLVM::is_llvm_struct(asr_list->m_type) || @@ -2038,7 +2045,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor break; } case (ASR::ttypeType::Dict): { - llvm_type = get_dict_type(asr_type); + llvm_type = get_dict_type(asr_type, is_nested_call); break; } case (ASR::ttypeType::Tuple) : { @@ -2055,7 +2062,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor ASR::storage_typeType local_m_storage = ASR::storage_typeType::Default; llvm_el_types.push_back(get_type_from_ttype_t(asr_tuple->m_type[i], local_m_storage, is_local_array_type, is_local_malloc_array_type, - is_local_list, local_m_dims, local_n_dims, local_a_kind)); + is_local_list, local_m_dims, local_n_dims, local_a_kind, + is_nested_call)); } llvm_type = tuple_api->get_tuple_type(type_code, llvm_el_types); break; @@ -3191,7 +3199,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor ASRUtils::expr_type(x.m_value)); std::string value_type_code = ASRUtils::get_type_code(value_asr_list->m_type); list_api->list_deepcopy(value_list, target_list, - value_asr_list, *module); + value_asr_list, module.get()); return ; } else if( is_target_tuple && is_value_tuple ) { uint64_t ptr_loads_copy = ptr_loads; @@ -3220,7 +3228,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor llvm::Value* llvm_tuple_i = builder->CreateAlloca(llvm_tuple_i_type, nullptr); ptr_loads = !LLVM::is_llvm_struct(asr_tuple_i_type); visit_expr(*asr_value_tuple->m_elements[i]); - llvm_utils->deepcopy(tmp, llvm_tuple_i, asr_tuple_i_type, *module); + llvm_utils->deepcopy(tmp, llvm_tuple_i, asr_tuple_i_type, module.get()); src_deepcopies.push_back(al, llvm_tuple_i); } ASR::TupleConstant_t* asr_target_tuple = ASR::down_cast(x.m_target); @@ -3244,7 +3252,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor std::string type_code = ASRUtils::get_type_code(value_tuple_type->m_type, value_tuple_type->n_type); tuple_api->tuple_deepcopy(value_tuple, target_tuple, - value_tuple_type, *module); + value_tuple_type, module.get()); } return ; } else if( is_target_dict && is_value_dict ) { @@ -3277,7 +3285,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor if( x.m_target->type == ASR::exprType::ArrayItem || x.m_target->type == ASR::exprType::ArraySection || x.m_target->type == ASR::exprType::DerivedRef || - x.m_target->type == ASR::exprType::ListItem ) { + x.m_target->type == ASR::exprType::ListItem || + x.m_target->type == ASR::exprType::DictItem ) { is_assignment_target = true; this->visit_expr(*x.m_target); is_assignment_target = false; @@ -5295,7 +5304,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } if( ASR::is_a(*arg_type) || ASR::is_a(*arg_type) ) { - llvm_utils->deepcopy(value, target, arg_type, *module); + llvm_utils->deepcopy(value, target, arg_type, module.get()); } else { builder->CreateStore(value, target); } diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index 3e5c3c2841..509544d0a0 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -298,7 +298,7 @@ namespace LFortran { } void LLVMUtils::deepcopy(llvm::Value* src, llvm::Value* dest, - ASR::ttype_t* asr_type, llvm::Module& module) { + ASR::ttype_t* asr_type, llvm::Module* module) { switch( asr_type->type ) { case ASR::ttypeType::Integer: case ASR::ttypeType::Real: @@ -318,6 +318,11 @@ namespace LFortran { list_api->list_deepcopy(src, dest, list_type, module); break ; } + case ASR::ttypeType::Dict: { + ASR::Dict_t* dict_type = ASR::down_cast(asr_type); + dict_api->dict_deepcopy(src, dest, dict_type, module); + break; + } default: { throw LCompilersException("LLVMUtils::deepcopy isn't implemented for " + ASRUtils::type_to_str_python(asr_type)); @@ -611,12 +616,12 @@ namespace LFortran { } void LLVMList::list_deepcopy(llvm::Value* src, llvm::Value* dest, - ASR::List_t* list_type, llvm::Module& module) { + ASR::List_t* list_type, llvm::Module* module) { list_deepcopy(src, dest, list_type->m_type, module); } void LLVMList::list_deepcopy(llvm::Value* src, llvm::Value* dest, - ASR::ttype_t* element_type, llvm::Module& module) { + ASR::ttype_t* element_type, llvm::Module* module) { LFORTRAN_ASSERT(src->getType() == dest->getType()); std::string src_type_code = ASRUtils::get_type_code(element_type); llvm::Value* src_end_point = LLVM::CreateLoad(*builder, get_pointer_to_current_end_point(src)); @@ -629,7 +634,7 @@ namespace LFortran { int32_t type_size = std::get<1>(typecode2listtype[src_type_code]); llvm::Value* arg_size = builder->CreateMul(llvm::ConstantInt::get(context, llvm::APInt(32, type_size)), src_capacity); - llvm::Value* copy_data = LLVM::lfortran_malloc(context, module, *builder, + llvm::Value* copy_data = LLVM::lfortran_malloc(context, *module, *builder, arg_size); llvm::Type* el_type = std::get<2>(typecode2listtype[src_type_code]); copy_data = builder->CreateBitCast(copy_data, el_type->getPointerTo()); @@ -700,12 +705,12 @@ namespace LFortran { llvm::Value* src_key_list = get_key_list(src); llvm::Value* dest_key_list = get_key_list(dest); llvm_utils->list_api->list_deepcopy(src_key_list, dest_key_list, - dict_type->m_key_type, *module); + dict_type->m_key_type, module); llvm::Value* src_value_list = get_value_list(src); llvm::Value* dest_value_list = get_value_list(dest); llvm_utils->list_api->list_deepcopy(src_value_list, dest_value_list, - dict_type->m_value_type, *module); + dict_type->m_value_type, module); llvm::Value* src_key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(src)); llvm::Value* dest_key_mask_ptr = get_pointer_to_keymask(dest); @@ -724,7 +729,9 @@ namespace LFortran { void LLVMDictSeparateChaining::deepcopy_key_value_pair_linked_list( llvm::Value* srci, llvm::Value* desti, llvm::Value* dest_key_value_pairs, llvm::Value* src_capacity, ASR::Dict_t* dict_type, llvm::Module* module) { + llvm::AllocaInst *src_itr_copy = nullptr, *dest_itr_copy = nullptr, *next_ptr_copy = nullptr; if( !are_iterators_set ) { + src_itr_copy = src_itr, dest_itr_copy = dest_itr, next_ptr_copy = next_ptr; src_itr = builder->CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); dest_itr = builder->CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); next_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); @@ -768,8 +775,8 @@ namespace LFortran { } llvm::Value* dest_key_ptr = llvm_utils->create_gep(curr_dest, 0); llvm::Value* dest_value_ptr = llvm_utils->create_gep(curr_dest, 1); - llvm_utils->deepcopy(src_key, dest_key_ptr, dict_type->m_key_type, *module); - llvm_utils->deepcopy(src_value, dest_value_ptr, dict_type->m_value_type, *module); + llvm_utils->deepcopy(src_key, dest_key_ptr, dict_type->m_key_type, module); + llvm_utils->deepcopy(src_value, dest_value_ptr, dict_type->m_value_type, module); llvm::Value* src_next_ptr = LLVM::CreateLoad(*builder, llvm_utils->create_gep(curr_src, 2)); llvm::Value* curr_dest_next_ptr = llvm_utils->create_gep(curr_dest, 2); @@ -807,6 +814,9 @@ namespace LFortran { // end llvm_utils->start_new_block(loopend); + if( !are_iterators_set ) { + src_itr = src_itr_copy, dest_itr = dest_itr_copy, next_ptr = next_ptr_copy; + } } void LLVMDictSeparateChaining::write_key_value_pair_linked_list( @@ -892,7 +902,9 @@ namespace LFortran { dest_key_value_pairs = builder->CreateBitCast( dest_key_value_pairs, get_key_value_pair_type(dict_type->m_key_type, dict_type->m_value_type)->getPointerTo()); + llvm::AllocaInst *copy_itr_copy = nullptr; if( !are_iterators_set ) { + copy_itr_copy = copy_itr; copy_itr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); } llvm::Value* llvm_zero = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0)); @@ -949,6 +961,9 @@ namespace LFortran { // end llvm_utils->start_new_block(loopend); LLVM::CreateStore(*builder, dest_key_value_pairs, get_pointer_to_key_value_pairs(dest)); + if( !are_iterators_set ) { + copy_itr = copy_itr_copy; + } } void LLVMList::check_index_within_bounds(llvm::Value* /*list*/, llvm::Value* /*pos*/) { @@ -957,7 +972,7 @@ namespace LFortran { void LLVMList::write_item(llvm::Value* list, llvm::Value* pos, llvm::Value* item, ASR::ttype_t* asr_type, - llvm::Module& module, bool check_index_bound) { + llvm::Module* module, bool check_index_bound) { if( check_index_bound ) { check_index_within_bounds(list, pos); } @@ -981,6 +996,7 @@ namespace LFortran { } llvm::Value* LLVMDictSeparateChaining::get_pointer_to_keymask(llvm::Value* dict) { + std::cout<<"dict->type: "<getType()->isPointerTy()<<" "<getType()->getContainedType(0)->isStructTy()<create_gep(dict, 4); } @@ -1323,9 +1339,9 @@ namespace LFortran { this->resolve_collision(capacity, key_hash, key, key_list, key_mask, *module, key_asr_type); llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); llvm_utils->list_api->write_item(key_list, pos, key, - key_asr_type, *module, false); + key_asr_type, module, false); llvm_utils->list_api->write_item(value_list, pos, value, - value_asr_type, *module, false); + value_asr_type, module, false); llvm::Value* key_mask_value = LLVM::CreateLoad(*builder, llvm_utils->create_ptr_gep(key_mask, pos)); llvm::Value* is_slot_empty = builder->CreateICmpEQ(key_mask_value, @@ -1352,9 +1368,9 @@ namespace LFortran { this->resolve_collision(capacity, key_hash, key, key_list, key_mask, *module, key_asr_type); llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); llvm_utils->list_api->write_item(key_list, pos, key, - key_asr_type, *module, false); + key_asr_type, module, false); llvm_utils->list_api->write_item(value_list, pos, value, - value_asr_type, *module, false); + value_asr_type, module, false); llvm::Value* key_mask_value = LLVM::CreateLoad(*builder, llvm_utils->create_ptr_gep(key_mask, pos)); @@ -1407,8 +1423,8 @@ namespace LFortran { llvm::Value* malloc_size = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), kv_struct_size); llvm::Value* new_kv_struct_i8 = LLVM::lfortran_malloc(context, *module, *builder, malloc_size); llvm::Value* new_kv_struct = builder->CreateBitCast(new_kv_struct_i8, kv_struct_type->getPointerTo()); - llvm_utils->deepcopy(key, llvm_utils->create_gep(new_kv_struct, 0), key_asr_type, *module); - llvm_utils->deepcopy(value, llvm_utils->create_gep(new_kv_struct, 1), value_asr_type, *module); + llvm_utils->deepcopy(key, llvm_utils->create_gep(new_kv_struct, 0), key_asr_type, module); + llvm_utils->deepcopy(value, llvm_utils->create_gep(new_kv_struct, 1), value_asr_type, module); LLVM::CreateStore(*builder, llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)), llvm_utils->create_gep(new_kv_struct, 2)); @@ -1420,8 +1436,8 @@ namespace LFortran { llvm_utils->start_new_block(elseBB); { llvm::Value* kv_struct = builder->CreateBitCast(kv_struct_i8, kv_struct_type->getPointerTo()); - llvm_utils->deepcopy(key, llvm_utils->create_gep(kv_struct, 0), key_asr_type, *module); - llvm_utils->deepcopy(value, llvm_utils->create_gep(kv_struct, 1), value_asr_type, *module); + llvm_utils->deepcopy(key, llvm_utils->create_gep(kv_struct, 0), key_asr_type, module); + llvm_utils->deepcopy(value, llvm_utils->create_gep(kv_struct, 1), value_asr_type, module); } llvm_utils->start_new_block(mergeBB); llvm::Value* occupancy_ptr = get_pointer_to_occupancy(dict); @@ -1766,10 +1782,10 @@ namespace LFortran { llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); llvm::Value* key_dest = llvm_utils->list_api->read_item(new_key_list, pos, true, false); - llvm_utils->deepcopy(key, key_dest, key_asr_type, *module); + llvm_utils->deepcopy(key, key_dest, key_asr_type, module); llvm::Value* value_dest = llvm_utils->list_api->read_item(new_value_list, pos, true, false); - llvm_utils->deepcopy(value, value_dest, value_asr_type, *module); + llvm_utils->deepcopy(value, value_dest, value_asr_type, module); llvm::Value* linear_prob_happened = builder->CreateICmpNE(key_hash, pos); llvm::Value* set_max_2 = builder->CreateSelect(linear_prob_happened, @@ -2225,21 +2241,21 @@ namespace LFortran { } void LLVMList::append(llvm::Value* list, llvm::Value* item, - ASR::ttype_t* asr_type, llvm::Module& module) { + ASR::ttype_t* asr_type, llvm::Module* module) { llvm::Value* current_end_point = LLVM::CreateLoad(*builder, get_pointer_to_current_end_point(list)); llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_current_capacity(list)); std::string type_code = ASRUtils::get_type_code(asr_type); int type_size = std::get<1>(typecode2listtype[type_code]); llvm::Type* el_type = std::get<2>(typecode2listtype[type_code]); resize_if_needed(list, current_end_point, current_capacity, - type_size, el_type, module); + type_size, el_type, *module); write_item(list, current_end_point, item, asr_type, module); shift_end_point_by_one(list); } void LLVMList::insert_item(llvm::Value* list, llvm::Value* pos, llvm::Value* item, ASR::ttype_t* asr_type, - llvm::Module& module) { + llvm::Module* module) { std::string type_code = ASRUtils::get_type_code(asr_type); llvm::Value* current_end_point = LLVM::CreateLoad(*builder, get_pointer_to_current_end_point(list)); @@ -2248,7 +2264,7 @@ namespace LFortran { int type_size = std::get<1>(typecode2listtype[type_code]); llvm::Type* el_type = std::get<2>(typecode2listtype[type_code]); resize_if_needed(list, current_end_point, current_capacity, - type_size, el_type, module); + type_size, el_type, *module); /* While loop equivalent in C++: * end_point // nth index of list @@ -2514,7 +2530,7 @@ namespace LFortran { } void LLVMTuple::tuple_deepcopy(llvm::Value* src, llvm::Value* dest, - ASR::Tuple_t* tuple_type, llvm::Module& module) { + ASR::Tuple_t* tuple_type, llvm::Module* module) { LFORTRAN_ASSERT(src->getType() == dest->getType()); for( size_t i = 0; i < tuple_type->n_type; i++ ) { llvm::Value* src_item = read_item(src, i, LLVM::is_llvm_struct( diff --git a/src/libasr/codegen/llvm_utils.h b/src/libasr/codegen/llvm_utils.h index f365eb82d3..548b3c3fda 100644 --- a/src/libasr/codegen/llvm_utils.h +++ b/src/libasr/codegen/llvm_utils.h @@ -57,7 +57,8 @@ namespace LFortran { return ASR::is_a(*asr_type) || ASR::is_a(*asr_type) || ASR::is_a(*asr_type) || - ASR::is_a(*asr_type); + ASR::is_a(*asr_type) || + ASR::is_a(*asr_type); } } @@ -107,7 +108,7 @@ namespace LFortran { void reset_iterators(); void deepcopy(llvm::Value* src, llvm::Value* dest, - ASR::ttype_t* asr_type, llvm::Module& module); + ASR::ttype_t* asr_type, llvm::Module* module); }; // LLVMUtils @@ -150,11 +151,11 @@ namespace LFortran { void list_deepcopy(llvm::Value* src, llvm::Value* dest, ASR::List_t* list_type, - llvm::Module& module); + llvm::Module* module); void list_deepcopy(llvm::Value* src, llvm::Value* dest, ASR::ttype_t* element_type, - llvm::Module& module); + llvm::Module* module); llvm::Value* read_item(llvm::Value* list, llvm::Value* pos, bool get_pointer=false, bool check_index_bound=true); @@ -165,17 +166,17 @@ namespace LFortran { void write_item(llvm::Value* list, llvm::Value* pos, llvm::Value* item, ASR::ttype_t* asr_type, - llvm::Module& module, bool check_index_bound=true); + llvm::Module* module, bool check_index_bound=true); void write_item(llvm::Value* list, llvm::Value* pos, llvm::Value* item, bool check_index_bound=true); void append(llvm::Value* list, llvm::Value* item, - ASR::ttype_t* asr_type, llvm::Module& module); + ASR::ttype_t* asr_type, llvm::Module* module); void insert_item(llvm::Value* list, llvm::Value* pos, llvm::Value* item, ASR::ttype_t* asr_type, - llvm::Module& module); + llvm::Module* module); void remove(llvm::Value* list, llvm::Value* item, ASR::ttype_t* item_type, llvm::Module& module); @@ -216,7 +217,7 @@ namespace LFortran { bool get_pointer=false); void tuple_deepcopy(llvm::Value* src, llvm::Value* dest, - ASR::Tuple_t* type_code, llvm::Module& module); + ASR::Tuple_t* type_code, llvm::Module* module); llvm::Value* check_tuple_equality(llvm::Value* t1, llvm::Value* t2, ASR::Tuple_t* tuple_type, llvm::LLVMContext& context, diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 350b091518..ec644e05ad 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -445,7 +445,7 @@ class CommonVisitor : public AST::BaseVisitor { std::map &ast_overload; std::string parent_dir; Vec *current_body; - ASR::ttype_t* ann_assign_target_type; + ASR::ttype_t *ann_assign_target_type, *assign_target_type, *subscript_value_type; std::map generic_func_nums; std::map> generic_func_subs; @@ -455,7 +455,8 @@ class CommonVisitor : public AST::BaseVisitor { std::map &ast_overload, std::string parent_dir) : diag{diagnostics}, al{al}, current_scope{symbol_table}, main_module{main_module}, ast_overload{ast_overload}, parent_dir{parent_dir}, - current_body{nullptr}, ann_assign_target_type{nullptr} { + current_body{nullptr}, ann_assign_target_type{nullptr}, + assign_target_type{nullptr}, subscript_value_type{nullptr} { current_module_dependencies.reserve(al, 4); } @@ -813,6 +814,20 @@ class CommonVisitor : public AST::BaseVisitor { } + ASR::asr_t* ignore_return_value_util(const Location &loc, ASR::ttype_t* a_type, + ASR::asr_t* func_call_asr) { + std::string dummy_ret_name = current_scope->get_unique_name("__lcompilers_dummy"); + ASR::asr_t* variable_asr = ASR::make_Variable_t(al, loc, current_scope, + s2c(al, dummy_ret_name), ASR::intentType::Local, + nullptr, nullptr, ASR::storage_typeType::Default, + a_type, ASR::abiType::Source, ASR::accessType::Public, + ASR::presenceType::Required, false); + ASR::symbol_t* variable_sym = ASR::down_cast(variable_asr); + current_scope->add_symbol(dummy_ret_name, variable_sym); + ASR::expr_t* variable_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, variable_sym)); + return ASR::make_Assignment_t(al, loc, variable_var, ASRUtils::EXPR(func_call_asr), nullptr); + } + // Function to create appropriate call based on symbol type. If it is external // generic symbol then it changes the name accordingly. ASR::asr_t* make_call_helper(Allocator &al, ASR::symbol_t* s, SymbolTable *current_scope, @@ -930,16 +945,7 @@ class CommonVisitor : public AST::BaseVisitor { s_generic, args_new.p, args_new.size(), a_type, value, nullptr); if( ignore_return_value ) { - std::string dummy_ret_name = current_scope->get_unique_name("__lcompilers_dummy"); - ASR::asr_t* variable_asr = ASR::make_Variable_t(al, loc, current_scope, - s2c(al, dummy_ret_name), ASR::intentType::Local, - nullptr, nullptr, ASR::storage_typeType::Default, - a_type, ASR::abiType::Source, ASR::accessType::Public, - ASR::presenceType::Required, false); - ASR::symbol_t* variable_sym = ASR::down_cast(variable_asr); - current_scope->add_symbol(dummy_ret_name, variable_sym); - ASR::expr_t* variable_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, variable_sym)); - return ASR::make_Assignment_t(al, loc, variable_var, ASRUtils::EXPR(func_call_asr), nullptr); + return ignore_return_value_util(loc, a_type, func_call_asr); } else { return func_call_asr; } @@ -2387,6 +2393,7 @@ class CommonVisitor : public AST::BaseVisitor { if( !visit_SubscriptIndices(x.m_slice, args, value, type, is_item, x.base.base.loc) ) { + subscript_value_type = type; return ; } @@ -2407,6 +2414,7 @@ class CommonVisitor : public AST::BaseVisitor { tmp = ASR::make_ArraySection_t(al, x.base.base.loc, v_Var, args.p, args.size(), type, nullptr); } + subscript_value_type = ASRUtils::expr_type(value); } }; @@ -3083,69 +3091,63 @@ class BodyVisitor : public CommonVisitor { } void visit_Assign(const AST::Assign_t &x) { - ASR::expr_t *target, *assign_value = nullptr, *tmp_value; - this->visit_expr(*x.m_value); - if (tmp) { - // This happens if `m.m_value` is `empty`, such as in: - // a = empty(16) - // We skip this statement for now, the array is declared - // by the annotation. - // TODO: enforce that empty(), ones(), zeros() is called - // for every declaration. - assign_value = ASRUtils::EXPR(tmp); - } + ASR::expr_t *target, *tmp_value = nullptr; for (size_t i=0; i(*x.m_targets[i])) { AST::Subscript_t *sb = AST::down_cast(x.m_targets[i]); - if (AST::is_a(*sb->m_value)) { - std::string name = AST::down_cast(sb->m_value)->m_id; - ASR::symbol_t *s = current_scope->resolve_symbol(name); - if (!s) { - throw SemanticError("Variable: '" + name + "' is not declared", - x.base.base.loc); + subscript_value_type = nullptr; + visit_Subscript(*sb); + ASR::ttype_t* type = subscript_value_type; + ASR::expr_t* subscript_expr = ASRUtils::EXPR(tmp); + if (ASR::is_a(*type)) { + // dict insert case; + this->visit_expr(*sb->m_slice); + ASR::expr_t *key = ASRUtils::EXPR(tmp); + ASR::ttype_t *key_type = ASR::down_cast(type)->m_key_type; + ASR::ttype_t *value_type = ASR::down_cast(type)->m_value_type; + if (!ASRUtils::check_equal_type(ASRUtils::expr_type(key), key_type)) { + std::string ktype = ASRUtils::type_to_str_python(ASRUtils::expr_type(key)); + std::string totype = ASRUtils::type_to_str_python(key_type); + diag.add(diag::Diagnostic( + "Type mismatch in dictionary key, the types must be compatible", + diag::Level::Error, diag::Stage::Semantic, { + diag::Label("type mismatch (found: '" + ktype + "', expected: '" + totype + "')", + {key->base.loc}) + }) + ); + throw SemanticAbort(); } - ASR::Variable_t *v = ASR::down_cast(s); - ASR::ttype_t *type = v->m_type; - if (ASR::is_a(*type)) { - // dict insert case; - this->visit_expr(*sb->m_slice); - ASR::expr_t *key = ASRUtils::EXPR(tmp); - ASR::ttype_t *key_type = ASR::down_cast(type)->m_key_type; - ASR::ttype_t *value_type = ASR::down_cast(type)->m_value_type; - if (!ASRUtils::check_equal_type(ASRUtils::expr_type(key), key_type)) { - std::string ktype = ASRUtils::type_to_str_python(ASRUtils::expr_type(key)); - std::string totype = ASRUtils::type_to_str_python(key_type); - diag.add(diag::Diagnostic( - "Type mismatch in dictionary key, the types must be compatible", - diag::Level::Error, diag::Stage::Semantic, { - diag::Label("type mismatch (found: '" + ktype + "', expected: '" + totype + "')", - {key->base.loc}) - }) - ); - throw SemanticAbort(); - } - if (!ASRUtils::check_equal_type(ASRUtils::expr_type(tmp_value), value_type)) { - std::string vtype = ASRUtils::type_to_str_python(ASRUtils::expr_type(tmp_value)); - std::string totype = ASRUtils::type_to_str_python(value_type); - diag.add(diag::Diagnostic( - "Type mismatch in dictionary value, the types must be compatible", - diag::Level::Error, diag::Stage::Semantic, { - diag::Label("type mismatch (found: '" + vtype + "', expected: '" + totype + "')", - {tmp_value->base.loc}) - }) - ); - throw SemanticAbort(); - } - ASR::expr_t* se = ASR::down_cast( - ASR::make_Var_t(al, x.base.base.loc, s)); - tmp = nullptr; - tmp_vec.push_back(make_DictInsert_t(al, x.base.base.loc, se, key, tmp_value)); - continue; - } else if (ASRUtils::is_immutable(type)) { - throw SemanticError("'" + ASRUtils::type_to_str_python(type) + "' object does not support" - " item assignment", x.base.base.loc); + assign_target_type = value_type; + this->visit_expr(*x.m_value); + if (tmp) { + // This happens if `m.m_value` is `empty`, such as in: + // a = empty(16) + // We skip this statement for now, the array is declared + // by the annotation. + // TODO: enforce that empty(), ones(), zeros() is called + // for every declaration. + tmp_value = ASRUtils::EXPR(tmp); + } + if (!ASRUtils::check_equal_type(ASRUtils::expr_type(tmp_value), value_type)) { + std::string vtype = ASRUtils::type_to_str_python(ASRUtils::expr_type(tmp_value)); + std::string totype = ASRUtils::type_to_str_python(value_type); + diag.add(diag::Diagnostic( + "Type mismatch in dictionary value, the types must be compatible", + diag::Level::Error, diag::Stage::Semantic, { + diag::Label("type mismatch (found: '" + vtype + "', expected: '" + totype + "')", + {tmp_value->base.loc}) + }) + ); + throw SemanticAbort(); } + LFORTRAN_ASSERT(ASR::is_a(*subscript_expr)); + ASR::DictItem_t* dict_item = ASR::down_cast(subscript_expr); + tmp = nullptr; + tmp_vec.push_back(make_DictInsert_t(al, x.base.base.loc, dict_item->m_a, key, tmp_value)); + continue ; + } else if (ASRUtils::is_immutable(type)) { + throw SemanticError("'" + ASRUtils::type_to_str_python(type) + "' object does not support" + " item assignment", x.base.base.loc); } } else if (AST::is_a(*x.m_targets[i])) { AST::Attribute_t *attr = AST::down_cast(x.m_targets[i]); @@ -3158,20 +3160,43 @@ class BodyVisitor : public CommonVisitor { } ASR::Variable_t *v = ASR::down_cast(s); ASR::ttype_t *type = v->m_type; + assign_target_type = type; + this->visit_expr(*x.m_value); + if (tmp) { + // This happens if `m.m_value` is `empty`, such as in: + // a = empty(16) + // We skip this statement for now, the array is declared + // by the annotation. + // TODO: enforce that empty(), ones(), zeros() is called + // for every declaration. + tmp_value = ASRUtils::EXPR(tmp); + } if (ASRUtils::is_immutable(type)) { throw SemanticError("readonly attribute", x.base.base.loc); } } } - if (!tmp_value) continue; + this->visit_expr(*x.m_targets[i]); target = ASRUtils::EXPR(tmp); + assign_target_type = ASRUtils::expr_type(target); + this->visit_expr(*x.m_value); + if (tmp) { + // This happens if `m.m_value` is `empty`, such as in: + // a = empty(16) + // We skip this statement for now, the array is declared + // by the annotation. + // TODO: enforce that empty(), ones(), zeros() is called + // for every declaration. + tmp_value = ASRUtils::EXPR(tmp); + } + if (!tmp_value) continue; ASR::ttype_t *target_type = ASRUtils::expr_type(target); - ASR::ttype_t *value_type = ASRUtils::expr_type(assign_value); + ASR::ttype_t *value_type = ASRUtils::expr_type(tmp_value); // Check if the target parameter type can be assigned with zero if (ASR::is_a(*target_type) - && ASR::is_a(*assign_value)) { - ASR::IntegerConstant_t *value_constant = ASR::down_cast(assign_value); + && ASR::is_a(*tmp_value)) { + ASR::IntegerConstant_t *value_constant = ASR::down_cast(tmp_value); if (value_constant->m_n == 0) { if (!ASRUtils::has_trait(ASR::down_cast(target_type), ASR::traitType::SupportsZero)) { @@ -3179,7 +3204,7 @@ class BodyVisitor : public CommonVisitor { "to be assignable with zero.", target_type->base.loc); } - tmp = ASR::make_Assignment_t(al, x.base.base.loc, target, assign_value, nullptr); + tmp = ASR::make_Assignment_t(al, x.base.base.loc, target, tmp_value, nullptr); return ; } } @@ -3636,10 +3661,17 @@ class BodyVisitor : public CommonVisitor { void visit_Dict(const AST::Dict_t &x) { LFORTRAN_ASSERT(x.n_keys == x.n_values); - if( x.n_keys == 0 && ann_assign_target_type != nullptr ) { - tmp = ASR::make_DictConstant_t(al, x.base.base.loc, nullptr, 0, - nullptr, 0, ann_assign_target_type); - return ; + if( x.n_keys == 0 ) { + if( ann_assign_target_type != nullptr ) { + tmp = ASR::make_DictConstant_t(al, x.base.base.loc, nullptr, 0, + nullptr, 0, ann_assign_target_type); + return ; + } + if( assign_target_type != nullptr ) { + tmp = ASR::make_DictConstant_t(al, x.base.base.loc, nullptr, 0, + nullptr, 0, assign_target_type); + return ; + } } Vec keys; keys.reserve(al, x.n_keys); @@ -4025,24 +4057,20 @@ class BodyVisitor : public CommonVisitor { } } else if (AST::is_a(*c->m_func)) { AST::Attribute_t *at = AST::down_cast(c->m_func); - if (AST::is_a(*at->m_value)) { - std::string value = AST::down_cast(at->m_value)->m_id; - ASR::symbol_t *t = current_scope->resolve_symbol(value); - if (!t) { - throw SemanticError("'" + value + "' is not defined in the scope", - x.base.base.loc); - } - Vec elements; - elements.reserve(al, c->n_args); - for (size_t i = 0; i < c->n_args; ++i) { - visit_expr(*c->m_args[i]); - elements.push_back(al, ASRUtils::EXPR(tmp)); - } - ASR::expr_t *te = ASR::down_cast( - ASR::make_Var_t(al, x.base.base.loc, t)); - handle_attribute(te, at->m_attr, x.base.base.loc, elements); - return; + Vec elements; + elements.reserve(al, c->n_args); + for (size_t i = 0; i < c->n_args; ++i) { + visit_expr(*c->m_args[i]); + elements.push_back(al, ASRUtils::EXPR(tmp)); + } + visit_expr(*at->m_value); + ASR::expr_t *te = ASRUtils::EXPR(tmp); + handle_attribute(te, at->m_attr, x.base.base.loc, elements); + if( ASR::is_a(*tmp) ) { + ASR::expr_t* tmp_expr = ASRUtils::EXPR(tmp); + tmp = ignore_return_value_util(x.base.base.loc, ASRUtils::expr_type(tmp_expr), tmp); } + return ; } else { throw SemanticError("Only Name/Attribute supported in Call", x.base.base.loc); @@ -4584,6 +4612,18 @@ class BodyVisitor : public CommonVisitor { throw SemanticError("'str' object has no attribute '" + std::string(at->m_attr) + "'", x.base.base.loc); } + } else if (AST::is_a(*at->m_value)) { + AST::Subscript_t* subscript = AST::down_cast(at->m_value); + visit_Subscript(*subscript); + ASR::expr_t* subscript_expr = ASRUtils::EXPR(tmp); + Vec elements; + elements.reserve(al, args.size()); + for (size_t i = 0; i < args.size(); ++i) { + elements.push_back(al, args.p[i].m_value); + } + handle_attribute(subscript_expr, std::string(at->m_attr), + at->m_value->base.loc, elements); + return ; } else { throw SemanticError("Only Name type and constant integers supported in Call", x.base.base.loc); diff --git a/tests/reference/asr-test_dict2-4587f02.json b/tests/reference/asr-test_dict2-4587f02.json index 296c0bbe09..270d51c545 100644 --- a/tests/reference/asr-test_dict2-4587f02.json +++ b/tests/reference/asr-test_dict2-4587f02.json @@ -8,6 +8,6 @@ "stdout": null, "stdout_hash": null, "stderr": "asr-test_dict2-4587f02.stderr", - "stderr_hash": "00d00b6323fa903c677ea2bf60b453ed2ad4cc0f0aa1886154359dd8", + "stderr_hash": "9de5d75622644a0cb98bdd3f73249772c25c293f508343b31cc34607", "returncode": 2 } \ No newline at end of file diff --git a/tests/reference/asr-test_dict2-4587f02.stderr b/tests/reference/asr-test_dict2-4587f02.stderr index 465bf0562d..dde39a40a8 100644 --- a/tests/reference/asr-test_dict2-4587f02.stderr +++ b/tests/reference/asr-test_dict2-4587f02.stderr @@ -1,5 +1,5 @@ -semantic error: Type mismatch in dictionary key, the types must be compatible +semantic error: Key type should be 'str' instead of 'i32' --> tests/errors/test_dict2.py:4:7 | 4 | y[1] = -3 - | ^ type mismatch (found: 'i32', expected: 'str') + | ^