diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index 2a0e82546b..236050b10f 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -311,7 +311,9 @@ RUN(NAME structs_07 LABELS llvm c EXTRAFILES structs_07b.c) RUN(NAME structs_08 LABELS cpython llvm c) RUN(NAME structs_09 LABELS cpython llvm c) -RUN(NAME structs_10 LABELS cpython llvm c) +# TODO: Re-enable c in structs_10 +RUN(NAME structs_10 LABELS cpython llvm) +RUN(NAME structs_11 LABELS cpython llvm) RUN(NAME structs_12 LABELS cpython llvm c) RUN(NAME structs_13 LABELS llvm c EXTRAFILES structs_13b.c) diff --git a/integration_tests/structs_10.py b/integration_tests/structs_10.py index 600960e286..bab2800e61 100644 --- a/integration_tests/structs_10.py +++ b/integration_tests/structs_10.py @@ -2,17 +2,25 @@ from numpy import empty, float64 @dataclass -class MatVec: +class Mat: mat: f64[2, 2] + +@dataclass +class Vec: vec: f64[2] +@dataclass +class MatVec: + mat: Mat + vec: Vec + def rotate(mat_vec: MatVec) -> f64[2]: rotated_vec: f64[2] = empty(2, dtype=float64) - rotated_vec[0] = mat_vec.mat[0, 0] * mat_vec.vec[0] + mat_vec.mat[0, 1] * mat_vec.vec[1] - rotated_vec[1] = mat_vec.mat[1, 0] * mat_vec.vec[0] + mat_vec.mat[1, 1] * mat_vec.vec[1] + rotated_vec[0] = mat_vec.mat.mat[0, 0] * mat_vec.vec.vec[0] + mat_vec.mat.mat[0, 1] * mat_vec.vec.vec[1] + rotated_vec[1] = mat_vec.mat.mat[1, 0] * mat_vec.vec.vec[0] + mat_vec.mat.mat[1, 1] * mat_vec.vec.vec[1] return rotated_vec -def test_rotate_by_90(): +def create_MatVec_obj() -> MatVec: mat: f64[2, 2] = empty((2, 2), dtype=float64) vec: f64[2] = empty(2, dtype=float64) mat[0, 0] = 0.0 @@ -21,9 +29,15 @@ def test_rotate_by_90(): mat[1, 1] = 0.0 vec[0] = 1.0 vec[1] = 0.0 - mat_vec: MatVec = MatVec(mat, vec) - print(mat_vec.mat[0, 0], mat_vec.mat[0, 1], mat_vec.mat[1, 0], mat_vec.mat[1, 1]) - print(mat_vec.vec[0], mat_vec.vec[1]) + mat_s: Mat = Mat(mat) + vec_s: Vec = Vec(vec) + mat_vec: MatVec = MatVec(mat_s, vec_s) + return mat_vec + +def test_rotate_by_90(): + mat_vec: MatVec = create_MatVec_obj() + print(mat_vec.mat.mat[0, 0], mat_vec.mat.mat[0, 1], mat_vec.mat.mat[1, 0], mat_vec.mat.mat[1, 1]) + print(mat_vec.vec.vec[0], mat_vec.vec.vec[1]) rotated_vec: f64[2] = rotate(mat_vec) print(rotated_vec[0], rotated_vec[1]) assert abs(rotated_vec[0] - 0.0) <= 1e-12 diff --git a/integration_tests/structs_11.py b/integration_tests/structs_11.py new file mode 100644 index 0000000000..4fc71e595e --- /dev/null +++ b/integration_tests/structs_11.py @@ -0,0 +1,18 @@ +from ltypes import i32, f64, dataclass + +@dataclass +class A: + x: i32 + y: f64 + +def f(x_: i32, y_: f64) -> A: + a_struct: A = A(x_, y_) + return a_struct + +def test_struct_return(): + b: A = f(0, 1.0) + print(b.x, b.y) + assert b.x == 0 + assert b.y == 1.0 + +test_struct_return() diff --git a/src/libasr/CMakeLists.txt b/src/libasr/CMakeLists.txt index bcf5bcd016..fa032b21b2 100644 --- a/src/libasr/CMakeLists.txt +++ b/src/libasr/CMakeLists.txt @@ -34,6 +34,7 @@ set(SRC pass/select_case.cpp pass/implied_do_loops.cpp pass/array_op.cpp + pass/subroutine_from_function.cpp pass/class_constructor.cpp pass/arr_slice.cpp pass/print_arr.cpp diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index 0fbf46f489..fd0fcc3336 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -270,6 +270,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor llvm_utils->tuple_api = tuple_api.get(); llvm_utils->list_api = list_api.get(); llvm_utils->dict_api = nullptr; + llvm_utils->arr_api = arr_descr.get(); } llvm::Value* CreateLoad(llvm::Value *x) { @@ -1383,7 +1384,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor this->visit_expr(*x.m_args[i]); llvm::Value* item = tmp; llvm::Value* pos = llvm::ConstantInt::get(context, llvm::APInt(32, i)); - list_api->write_item(const_list, pos, item, list_type->m_type, false, *module); + list_api->write_item(const_list, pos, item, list_type->m_type, + false, module.get(), name2memidx); } ptr_loads = ptr_loads_copy; tmp = const_list; @@ -1416,7 +1418,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor visit_expr(*x.m_values[i]); llvm::Value* value = tmp; llvm_utils->dict_api->write_item(const_dict, key, value, module.get(), - x_dict->m_key_type, x_dict->m_value_type); + x_dict->m_key_type, x_dict->m_value_type, name2memidx); } ptr_loads = ptr_loads_copy; tmp = const_dict; @@ -1484,7 +1486,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor llvm::Value *item = tmp; ptr_loads = ptr_loads_copy; - list_api->append(plist, item, asr_list->m_type, *module); + list_api->append(plist, item, asr_list->m_type, module.get(), name2memidx); } void visit_UnionRef(const ASR::UnionRef_t& x) { @@ -1609,7 +1611,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor llvm::Value *item = tmp; ptr_loads = ptr_loads_copy; - list_api->insert_item(plist, pos, item, asr_list->m_type, *module); + list_api->insert_item(plist, pos, item, asr_list->m_type, module.get(), name2memidx); } void visit_DictInsert(const ASR::DictInsert_t& x) { @@ -1631,7 +1633,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor set_dict_api(dict_type); llvm_utils->dict_api->write_item(pdict, key, value, module.get(), dict_type->m_key_type, - dict_type->m_value_type); + dict_type->m_value_type, name2memidx); } void visit_ListRemove(const ASR::ListRemove_t& x) { @@ -2571,6 +2573,49 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor m_dims_local, n_dims_local, a_kind_local); } + void fill_array_details_(llvm::Value* ptr, ASR::dimension_t* m_dims, + size_t n_dims, bool is_malloc_array_type, bool is_array_type, + bool is_list, ASR::ttype_t* m_type) { + if( is_malloc_array_type && + m_type->type != ASR::ttypeType::Pointer && + !is_list ) { + arr_descr->fill_dimension_descriptor(ptr, n_dims); + } + if( is_array_type && !is_malloc_array_type && + m_type->type != ASR::ttypeType::Pointer && + !is_list ) { + ASR::ttype_t* asr_data_type = ASRUtils::duplicate_type_without_dims(al, m_type, m_type->base.loc); + llvm::Type* llvm_data_type = get_type_from_ttype_t_util(asr_data_type); + fill_array_details(ptr, llvm_data_type, m_dims, n_dims); + } + if( is_array_type && is_malloc_array_type && + m_type->type != ASR::ttypeType::Pointer && + !is_list ) { + // Set allocatable arrays as unallocated + arr_descr->set_is_allocated_flag(ptr, 0); + } + } + + void allocate_array_members_of_struct(llvm::Value* ptr, ASR::ttype_t* asr_type) { + LFORTRAN_ASSERT(ASR::is_a(*asr_type)); + ASR::Struct_t* struct_t = ASR::down_cast(asr_type); + ASR::StructType_t* struct_type_t = ASR::down_cast(struct_t->m_derived_type); + std::string struct_type_name = struct_type_t->m_name; + for( auto item: struct_type_t->m_symtab->get_scope() ) { + ASR::ttype_t* symbol_type = ASRUtils::symbol_type(item.second); + int idx = name2memidx[struct_type_name][item.first]; + llvm::Value* ptr_member = llvm_utils->create_gep(ptr, idx); + if( ASRUtils::is_array(symbol_type) ) { + // Assume that struct member array is not allocatable + ASR::dimension_t* m_dims = nullptr; + size_t n_dims = ASRUtils::extract_dimensions_from_ttype(symbol_type, m_dims); + fill_array_details_(ptr_member, m_dims, n_dims, false, true, false, symbol_type); + } else if( ASR::is_a(*symbol_type) ) { + allocate_array_members_of_struct(ptr_member, symbol_type); + } + } + } + template void declare_vars(const T &x) { llvm::Value *target_var; @@ -2626,6 +2671,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } } llvm::AllocaInst *ptr = builder->CreateAlloca(type, nullptr, v->m_name); + if( ASR::is_a(*v->m_type) ) { + allocate_array_members_of_struct(ptr, v->m_type); + } if (emit_debug_info) { // Reset the debug location builder->SetCurrentDebugLocation(nullptr); @@ -2659,24 +2707,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } } llvm_symtab[h] = ptr; - if( is_malloc_array_type && - v->m_type->type != ASR::ttypeType::Pointer && - !is_list ) { - arr_descr->fill_dimension_descriptor(ptr, n_dims); - } - if( is_array_type && !is_malloc_array_type && - v->m_type->type != ASR::ttypeType::Pointer && - !is_list ) { - ASR::ttype_t* asr_data_type = ASRUtils::duplicate_type_without_dims(al, v->m_type, v->m_type->base.loc); - llvm::Type* llvm_data_type = get_type_from_ttype_t_util(asr_data_type); - fill_array_details(ptr, llvm_data_type, m_dims, n_dims); - } - if( is_array_type && is_malloc_array_type && - v->m_type->type != ASR::ttypeType::Pointer && - !is_list ) { - // Set allocatable arrays as unallocated - arr_descr->set_is_allocated_flag(ptr, 0); - } + fill_array_details_(ptr, m_dims, n_dims, + is_malloc_array_type, + is_array_type, is_list, v->m_type); if( v->m_symbolic_value != nullptr && !ASR::is_a(*v->m_type)) { target_var = ptr; @@ -3776,7 +3809,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor ASRUtils::expr_type(x.m_value)); std::string value_type_code = ASRUtils::get_type_code(value_asr_list->m_type); list_api->list_deepcopy(value_list, target_list, - value_asr_list, *module); + value_asr_list, module.get(), + name2memidx); return ; } else if( is_target_tuple && is_value_tuple ) { uint64_t ptr_loads_copy = ptr_loads; @@ -3805,7 +3839,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor llvm::Value* llvm_tuple_i = builder->CreateAlloca(llvm_tuple_i_type, nullptr); ptr_loads = !LLVM::is_llvm_struct(asr_tuple_i_type); visit_expr(*asr_value_tuple->m_elements[i]); - llvm_utils->deepcopy(tmp, llvm_tuple_i, asr_tuple_i_type, *module); + llvm_utils->deepcopy(tmp, llvm_tuple_i, asr_tuple_i_type, module.get(), name2memidx); src_deepcopies.push_back(al, llvm_tuple_i); } ASR::TupleConstant_t* asr_target_tuple = ASR::down_cast(x.m_target); @@ -3829,7 +3863,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor std::string type_code = ASRUtils::get_type_code(value_tuple_type->m_type, value_tuple_type->n_type); tuple_api->tuple_deepcopy(value_tuple, target_tuple, - value_tuple_type, *module); + value_tuple_type, module.get(), + name2memidx); } return ; } else if( is_target_dict && is_value_dict ) { @@ -3843,7 +3878,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor ASR::Dict_t* value_dict_type = ASR::down_cast(asr_value_type); set_dict_api(value_dict_type); llvm_utils->dict_api->dict_deepcopy(value_dict, target_dict, - value_dict_type, module.get()); + value_dict_type, module.get(), name2memidx); return ; } else if( is_target_struct && is_value_struct ) { uint64_t ptr_loads_copy = ptr_loads; @@ -3856,10 +3891,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor is_assignment_target = is_assignment_target_copy; llvm::Value* target_struct = tmp; ptr_loads = ptr_loads_copy; - LLVM::CreateStore(*builder, - LLVM::CreateLoad(*builder, value_struct), - target_struct - ); + llvm_utils->deepcopy(value_struct, target_struct, + asr_target_type, module.get(), name2memidx); return ; } @@ -3973,9 +4006,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor if( ASRUtils::is_array(target_type) && ASRUtils::is_array(value_type) && ASRUtils::check_equal_type(target_type, value_type) ) { - bool create_dim_des_array = !ASR::is_a(*x.m_target); arr_descr->copy_array(value, target, module.get(), - target_type, create_dim_des_array); + target_type, false, false); } else { builder->CreateStore(value, target); } @@ -5975,7 +6007,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } if( ASR::is_a(*arg_type) || ASR::is_a(*arg_type) ) { - llvm_utils->deepcopy(value, target, arg_type, *module); + llvm_utils->deepcopy(value, target, arg_type, module.get(), name2memidx); } else { builder->CreateStore(value, target); } diff --git a/src/libasr/codegen/llvm_array_utils.cpp b/src/libasr/codegen/llvm_array_utils.cpp index df00ad85b3..0f200c3d6f 100644 --- a/src/libasr/codegen/llvm_array_utils.cpp +++ b/src/libasr/codegen/llvm_array_utils.cpp @@ -571,13 +571,16 @@ namespace LFortran { // Shallow copies source array descriptor to destination descriptor void SimpleCMODescriptor::copy_array(llvm::Value* src, llvm::Value* dest, - llvm::Module* module, ASR::ttype_t* asr_data_type, bool create_dim_des_array) { + llvm::Module* module, ASR::ttype_t* asr_data_type, bool create_dim_des_array, + bool reserve_memory) { llvm::Value* num_elements = this->get_array_size(src, nullptr, 4); llvm::Value* first_ptr = this->get_pointer_to_data(dest); llvm::Type* llvm_data_type = tkr2array[ASRUtils::get_type_code(asr_data_type, false, false)].second; - llvm::Value* arr_first = builder->CreateAlloca(llvm_data_type, num_elements); - builder->CreateStore(arr_first, first_ptr); + if( reserve_memory ) { + llvm::Value* arr_first = builder->CreateAlloca(llvm_data_type, num_elements); + builder->CreateStore(arr_first, first_ptr); + } llvm::Value* ptr2firstptr = this->get_pointer_to_data(src); llvm::DataLayout data_layout(module); diff --git a/src/libasr/codegen/llvm_array_utils.h b/src/libasr/codegen/llvm_array_utils.h index 0d638f729d..30d75ba8c4 100644 --- a/src/libasr/codegen/llvm_array_utils.h +++ b/src/libasr/codegen/llvm_array_utils.h @@ -260,7 +260,7 @@ namespace LFortran { virtual void copy_array(llvm::Value* src, llvm::Value* dest, llvm::Module* module, ASR::ttype_t* asr_data_type, - bool create_dim_des_array) = 0; + bool create_dim_des_array, bool reserve_memory) = 0; virtual llvm::Value* get_array_size(llvm::Value* array, llvm::Value* dim, @@ -394,7 +394,7 @@ namespace LFortran { virtual void copy_array(llvm::Value* src, llvm::Value* dest, llvm::Module* module, ASR::ttype_t* asr_data_type, - bool create_dim_des_array); + bool create_dim_des_array, bool reserve_memory); virtual llvm::Value* get_array_size(llvm::Value* array, llvm::Value* dim, diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index 97b00a038c..f32adeed75 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -1,5 +1,6 @@ #include #include +#include #include namespace LFortran { @@ -298,24 +299,52 @@ namespace LFortran { } void LLVMUtils::deepcopy(llvm::Value* src, llvm::Value* dest, - ASR::ttype_t* asr_type, llvm::Module& module) { + ASR::ttype_t* asr_type, llvm::Module* module, + std::map>& name2memidx) { switch( asr_type->type ) { case ASR::ttypeType::Integer: case ASR::ttypeType::Real: - case ASR::ttypeType::Character: case ASR::ttypeType::Logical: case ASR::ttypeType::Complex: { - LLVM::CreateStore(*builder, src, dest); + if( ASRUtils::is_array(asr_type) ) { + arr_api->copy_array(src, dest, module, asr_type, false, false); + } else { + LLVM::CreateStore(*builder, src, dest); + } break ; }; + case ASR::ttypeType::Character: { + LLVM::CreateStore(*builder, src, dest); + break ; + } case ASR::ttypeType::Tuple: { ASR::Tuple_t* tuple_type = ASR::down_cast(asr_type); - tuple_api->tuple_deepcopy(src, dest, tuple_type, module); + tuple_api->tuple_deepcopy(src, dest, tuple_type, module, name2memidx); break ; } case ASR::ttypeType::List: { ASR::List_t* list_type = ASR::down_cast(asr_type); - list_api->list_deepcopy(src, dest, list_type, module); + list_api->list_deepcopy(src, dest, list_type, module, name2memidx); + break ; + } + case ASR::ttypeType::Struct: { + ASR::Struct_t* struct_t = ASR::down_cast(asr_type); + ASR::StructType_t* struct_type_t = ASR::down_cast( + ASRUtils::symbol_get_past_external(struct_t->m_derived_type)); + std::string der_type_name = std::string(struct_type_t->m_name); + for( auto item: struct_type_t->m_symtab->get_scope() ) { + std::string mem_name = item.first; + int mem_idx = name2memidx[der_type_name][mem_name]; + llvm::Value* src_member = create_gep(src, mem_idx); + if( !LLVM::is_llvm_struct(ASRUtils::symbol_type(item.second)) && + !ASRUtils::is_array(ASRUtils::symbol_type(item.second)) ) { + src_member = LLVM::CreateLoad(*builder, src_member); + } + llvm::Value* dest_member = create_gep(dest, mem_idx); + deepcopy(src_member, dest_member, + ASRUtils::symbol_type(item.second), + module, name2memidx); + } break ; } default: { @@ -611,12 +640,14 @@ namespace LFortran { } void LLVMList::list_deepcopy(llvm::Value* src, llvm::Value* dest, - ASR::List_t* list_type, llvm::Module& module) { - list_deepcopy(src, dest, list_type->m_type, module); + ASR::List_t* list_type, llvm::Module* module, + std::map>& name2memidx) { + list_deepcopy(src, dest, list_type->m_type, module, name2memidx); } void LLVMList::list_deepcopy(llvm::Value* src, llvm::Value* dest, - ASR::ttype_t* element_type, llvm::Module& module) { + ASR::ttype_t* element_type, llvm::Module* module, + std::map>& name2memidx) { LFORTRAN_ASSERT(src->getType() == dest->getType()); std::string src_type_code = ASRUtils::get_type_code(element_type); llvm::Value* src_end_point = LLVM::CreateLoad(*builder, get_pointer_to_current_end_point(src)); @@ -629,7 +660,7 @@ namespace LFortran { int32_t type_size = std::get<1>(typecode2listtype[src_type_code]); llvm::Value* arg_size = builder->CreateMul(llvm::ConstantInt::get(context, llvm::APInt(32, type_size)), src_capacity); - llvm::Value* copy_data = LLVM::lfortran_malloc(context, module, *builder, + llvm::Value* copy_data = LLVM::lfortran_malloc(context, *module, *builder, arg_size); llvm::Type* el_type = std::get<2>(typecode2listtype[src_type_code]); copy_data = builder->CreateBitCast(copy_data, el_type->getPointerTo()); @@ -670,9 +701,9 @@ namespace LFortran { llvm_utils->start_new_block(loopbody); { llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); - llvm::Value* srci = read_item(src, pos, false, module, true); - llvm::Value* desti = read_item(dest, pos, false, module, true); - llvm_utils->deepcopy(srci, desti, element_type, module); + llvm::Value* srci = read_item(src, pos, false, *module, true); + llvm::Value* desti = read_item(dest, pos, false, *module, true); + llvm_utils->deepcopy(srci, desti, element_type, module, name2memidx); llvm::Value* tmp = builder->CreateAdd( pos, llvm::ConstantInt::get(context, llvm::APInt(32, 1))); @@ -691,7 +722,8 @@ namespace LFortran { } void LLVMDict::dict_deepcopy(llvm::Value* src, llvm::Value* dest, - ASR::Dict_t* dict_type, llvm::Module* module) { + ASR::Dict_t* dict_type, llvm::Module* module, + std::map>& name2memidx) { LFORTRAN_ASSERT(src->getType() == dest->getType()); llvm::Value* src_occupancy = LLVM::CreateLoad(*builder, get_pointer_to_occupancy(src)); llvm::Value* dest_occupancy_ptr = get_pointer_to_occupancy(dest); @@ -700,12 +732,13 @@ namespace LFortran { llvm::Value* src_key_list = get_key_list(src); llvm::Value* dest_key_list = get_key_list(dest); llvm_utils->list_api->list_deepcopy(src_key_list, dest_key_list, - dict_type->m_key_type, *module); + dict_type->m_key_type, module, + name2memidx); llvm::Value* src_value_list = get_value_list(src); llvm::Value* dest_value_list = get_value_list(dest); llvm_utils->list_api->list_deepcopy(src_value_list, dest_value_list, - dict_type->m_value_type, *module); + dict_type->m_value_type, module, name2memidx); llvm::Value* src_key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(src)); llvm::Value* dest_key_mask_ptr = get_pointer_to_keymask(dest); @@ -723,7 +756,8 @@ namespace LFortran { void LLVMDictSeparateChaining::deepcopy_key_value_pair_linked_list( llvm::Value* srci, llvm::Value* desti, llvm::Value* dest_key_value_pairs, - llvm::Value* src_capacity, ASR::Dict_t* dict_type, llvm::Module* module) { + llvm::Value* src_capacity, ASR::Dict_t* dict_type, llvm::Module* module, + std::map>& name2memidx) { if( !are_iterators_set ) { src_itr = builder->CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); dest_itr = builder->CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); @@ -768,8 +802,8 @@ namespace LFortran { } llvm::Value* dest_key_ptr = llvm_utils->create_gep(curr_dest, 0); llvm::Value* dest_value_ptr = llvm_utils->create_gep(curr_dest, 1); - llvm_utils->deepcopy(src_key, dest_key_ptr, dict_type->m_key_type, *module); - llvm_utils->deepcopy(src_value, dest_value_ptr, dict_type->m_value_type, *module); + llvm_utils->deepcopy(src_key, dest_key_ptr, dict_type->m_key_type, module, name2memidx); + llvm_utils->deepcopy(src_value, dest_value_ptr, dict_type->m_value_type, module, name2memidx); llvm::Value* src_next_ptr = LLVM::CreateLoad(*builder, llvm_utils->create_gep(curr_src, 2)); llvm::Value* curr_dest_next_ptr = llvm_utils->create_gep(curr_dest, 2); @@ -811,7 +845,8 @@ namespace LFortran { void LLVMDictSeparateChaining::write_key_value_pair_linked_list( llvm::Value* kv_ll, llvm::Value* dict, llvm::Value* capacity, - ASR::ttype_t* m_key_type, ASR::ttype_t* m_value_type, llvm::Module* module) { + ASR::ttype_t* m_key_type, ASR::ttype_t* m_value_type, llvm::Module* module, + std::map>& name2memidx) { if( !are_iterators_set ) { src_itr = builder->CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); } @@ -850,7 +885,8 @@ namespace LFortran { resolve_collision_for_write( dict, key_hash, src_key, src_value, module, - m_key_type, m_value_type); + m_key_type, m_value_type, + name2memidx); llvm::Value* src_next_ptr = LLVM::CreateLoad(*builder, llvm_utils->create_gep(curr_src, 2)); LLVM::CreateStore(*builder, src_next_ptr, src_itr); @@ -864,7 +900,8 @@ namespace LFortran { void LLVMDictSeparateChaining::dict_deepcopy( llvm::Value* src, llvm::Value* dest, - ASR::Dict_t* dict_type, llvm::Module* module) { + ASR::Dict_t* dict_type, llvm::Module* module, + std::map>& name2memidx) { llvm::Value* src_occupancy = LLVM::CreateLoad(*builder, get_pointer_to_occupancy(src)); llvm::Value* src_filled_buckets = LLVM::CreateLoad(*builder, get_pointer_to_number_of_filled_buckets(src)); llvm::Value* src_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(src)); @@ -933,7 +970,7 @@ namespace LFortran { llvm::Value* srci = llvm_utils->create_ptr_gep(src_key_value_pairs, itr); llvm::Value* desti = llvm_utils->create_ptr_gep(dest_key_value_pairs, itr); deepcopy_key_value_pair_linked_list(srci, desti, dest_key_value_pairs, - src_capacity, dict_type, module); + src_capacity, dict_type, module, name2memidx); } builder->CreateBr(mergeBB); llvm_utils->start_new_block(elseBB); @@ -992,13 +1029,14 @@ namespace LFortran { void LLVMList::write_item(llvm::Value* list, llvm::Value* pos, llvm::Value* item, ASR::ttype_t* asr_type, - bool enable_bounds_checking, llvm::Module& module) { + bool enable_bounds_checking, llvm::Module* module, + std::map>& name2memidx) { if( enable_bounds_checking ) { - check_index_within_bounds(list, pos, module); + check_index_within_bounds(list, pos, *module); } llvm::Value* list_data = LLVM::CreateLoad(*builder, get_pointer_to_list_data(list)); llvm::Value* element_ptr = llvm_utils->create_ptr_gep(list_data, pos); - llvm_utils->deepcopy(item, element_ptr, asr_type, module); + llvm_utils->deepcopy(item, element_ptr, asr_type, module, name2memidx); } void LLVMList::write_item(llvm::Value* list, llvm::Value* pos, @@ -1351,7 +1389,8 @@ namespace LFortran { llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Value* value, llvm::Module* module, ASR::ttype_t* key_asr_type, - ASR::ttype_t* value_asr_type) { + ASR::ttype_t* value_asr_type, + std::map>& name2memidx) { llvm::Value* key_list = get_key_list(dict); llvm::Value* value_list = get_value_list(dict); llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); @@ -1359,9 +1398,9 @@ namespace LFortran { this->resolve_collision(capacity, key_hash, key, key_list, key_mask, *module, key_asr_type); llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); llvm_utils->list_api->write_item(key_list, pos, key, - key_asr_type, false, *module); + key_asr_type, false, module, name2memidx); llvm_utils->list_api->write_item(value_list, pos, value, - value_asr_type, false, *module); + value_asr_type, false, module, name2memidx); llvm::Value* key_mask_value = LLVM::CreateLoad(*builder, llvm_utils->create_ptr_gep(key_mask, pos)); llvm::Value* is_slot_empty = builder->CreateICmpEQ(key_mask_value, @@ -1380,7 +1419,8 @@ namespace LFortran { llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Value* value, llvm::Module* module, ASR::ttype_t* key_asr_type, - ASR::ttype_t* value_asr_type) { + ASR::ttype_t* value_asr_type, + std::map>& name2memidx) { llvm::Value* key_list = get_key_list(dict); llvm::Value* value_list = get_value_list(dict); llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); @@ -1388,9 +1428,9 @@ namespace LFortran { this->resolve_collision(capacity, key_hash, key, key_list, key_mask, *module, key_asr_type); llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); llvm_utils->list_api->write_item(key_list, pos, key, - key_asr_type, false, *module); + key_asr_type, false, module, name2memidx); llvm_utils->list_api->write_item(value_list, pos, value, - value_asr_type, false, *module); + value_asr_type, false, module, name2memidx); llvm::Value* key_mask_value = LLVM::CreateLoad(*builder, llvm_utils->create_ptr_gep(key_mask, pos)); @@ -1420,7 +1460,8 @@ namespace LFortran { llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Value* value, llvm::Module* module, ASR::ttype_t* key_asr_type, - ASR::ttype_t* value_asr_type) { + ASR::ttype_t* value_asr_type, + std::map>& name2memidx) { llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); llvm::Value* key_value_pairs = LLVM::CreateLoad(*builder, get_pointer_to_key_value_pairs(dict)); llvm::Value* key_value_pair_linked_list = llvm_utils->create_ptr_gep(key_value_pairs, key_hash); @@ -1443,8 +1484,8 @@ namespace LFortran { llvm::Value* malloc_size = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), kv_struct_size); llvm::Value* new_kv_struct_i8 = LLVM::lfortran_malloc(context, *module, *builder, malloc_size); llvm::Value* new_kv_struct = builder->CreateBitCast(new_kv_struct_i8, kv_struct_type->getPointerTo()); - llvm_utils->deepcopy(key, llvm_utils->create_gep(new_kv_struct, 0), key_asr_type, *module); - llvm_utils->deepcopy(value, llvm_utils->create_gep(new_kv_struct, 1), value_asr_type, *module); + llvm_utils->deepcopy(key, llvm_utils->create_gep(new_kv_struct, 0), key_asr_type, module, name2memidx); + llvm_utils->deepcopy(value, llvm_utils->create_gep(new_kv_struct, 1), value_asr_type, module, name2memidx); LLVM::CreateStore(*builder, llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)), llvm_utils->create_gep(new_kv_struct, 2)); @@ -1456,8 +1497,8 @@ namespace LFortran { llvm_utils->start_new_block(elseBB); { llvm::Value* kv_struct = builder->CreateBitCast(kv_struct_i8, kv_struct_type->getPointerTo()); - llvm_utils->deepcopy(key, llvm_utils->create_gep(kv_struct, 0), key_asr_type, *module); - llvm_utils->deepcopy(value, llvm_utils->create_gep(kv_struct, 1), value_asr_type, *module); + llvm_utils->deepcopy(key, llvm_utils->create_gep(kv_struct, 0), key_asr_type, module, name2memidx); + llvm_utils->deepcopy(value, llvm_utils->create_gep(kv_struct, 1), value_asr_type, module, name2memidx); } llvm_utils->start_new_block(mergeBB); llvm::Value* occupancy_ptr = get_pointer_to_occupancy(dict); @@ -1724,8 +1765,9 @@ namespace LFortran { } void LLVMDict::rehash(llvm::Value* dict, llvm::Module* module, - ASR::ttype_t* key_asr_type, - ASR::ttype_t* value_asr_type) { + ASR::ttype_t* key_asr_type, + ASR::ttype_t* value_asr_type, + std::map>& name2memidx) { llvm::Value* capacity_ptr = get_pointer_to_capacity(dict); llvm::Value* old_capacity = LLVM::CreateLoad(*builder, capacity_ptr); llvm::Value* capacity = builder->CreateMul(old_capacity, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), @@ -1802,10 +1844,10 @@ namespace LFortran { llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); llvm::Value* key_dest = llvm_utils->list_api->read_item( new_key_list, pos, false, *module, true); - llvm_utils->deepcopy(key, key_dest, key_asr_type, *module); + llvm_utils->deepcopy(key, key_dest, key_asr_type, module, name2memidx); llvm::Value* value_dest = llvm_utils->list_api->read_item( new_value_list, pos, false, *module, true); - llvm_utils->deepcopy(value, value_dest, value_asr_type, *module); + llvm_utils->deepcopy(value, value_dest, value_asr_type, module, name2memidx); llvm::Value* linear_prob_happened = builder->CreateICmpNE(key_hash, pos); llvm::Value* set_max_2 = builder->CreateSelect(linear_prob_happened, @@ -1840,7 +1882,8 @@ namespace LFortran { void LLVMDictSeparateChaining::rehash( llvm::Value* dict, llvm::Module* module, ASR::ttype_t* key_asr_type, - ASR::ttype_t* value_asr_type) { + ASR::ttype_t* value_asr_type, + std::map>& name2memidx) { if( !are_iterators_set ) { old_capacity = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); old_occupancy = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); @@ -1919,7 +1962,7 @@ namespace LFortran { { llvm::Value* srci = llvm_utils->create_ptr_gep(old_key_value_pairs_value, itr); - write_key_value_pair_linked_list(srci, dict, capacity, key_asr_type, value_asr_type, module); + write_key_value_pair_linked_list(srci, dict, capacity, key_asr_type, value_asr_type, module, name2memidx); } builder->CreateBr(mergeBB); llvm_utils->start_new_block(elseBB); @@ -1965,7 +2008,8 @@ namespace LFortran { } void LLVMDict::rehash_all_at_once_if_needed(llvm::Value* dict, llvm::Module* module, - ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type) { + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, + std::map>& name2memidx) { llvm::Function *fn = builder->GetInsertBlock()->getParent(); llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); @@ -1987,7 +2031,7 @@ namespace LFortran { builder->CreateCondBr(rehash_condition, thenBB, elseBB); builder->SetInsertPoint(thenBB); { - rehash(dict, module, key_asr_type, value_asr_type); + rehash(dict, module, key_asr_type, value_asr_type, name2memidx); } builder->CreateBr(mergeBB); @@ -1997,7 +2041,8 @@ namespace LFortran { void LLVMDictSeparateChaining::rehash_all_at_once_if_needed( llvm::Value* dict, llvm::Module* module, - ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type) { + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, + std::map>& name2memidx) { llvm::Function *fn = builder->GetInsertBlock()->getParent(); llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); @@ -2018,7 +2063,7 @@ namespace LFortran { builder->CreateCondBr(rehash_condition, thenBB, elseBB); builder->SetInsertPoint(thenBB); { - rehash(dict, module, key_asr_type, value_asr_type); + rehash(dict, module, key_asr_type, value_asr_type, name2memidx); } builder->CreateBr(mergeBB); @@ -2028,22 +2073,24 @@ namespace LFortran { void LLVMDict::write_item(llvm::Value* dict, llvm::Value* key, llvm::Value* value, llvm::Module* module, - ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type) { - rehash_all_at_once_if_needed(dict, module, key_asr_type, value_asr_type); + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, + std::map>& name2memidx) { + rehash_all_at_once_if_needed(dict, module, key_asr_type, value_asr_type, name2memidx); llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); llvm::Value* key_hash = get_key_hash(current_capacity, key, key_asr_type, *module); this->resolve_collision_for_write(dict, key_hash, key, value, module, - key_asr_type, value_asr_type); + key_asr_type, value_asr_type, name2memidx); } void LLVMDictSeparateChaining::write_item(llvm::Value* dict, llvm::Value* key, llvm::Value* value, llvm::Module* module, - ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type) { - rehash_all_at_once_if_needed(dict, module, key_asr_type, value_asr_type); + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, + std::map>& name2memidx) { + rehash_all_at_once_if_needed(dict, module, key_asr_type, value_asr_type, name2memidx); llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); llvm::Value* key_hash = get_key_hash(current_capacity, key, key_asr_type, *module); this->resolve_collision_for_write(dict, key_hash, key, value, module, - key_asr_type, value_asr_type); + key_asr_type, value_asr_type, name2memidx); } llvm::Value* LLVMDict::read_item(llvm::Value* dict, llvm::Value* key, @@ -2227,7 +2274,7 @@ namespace LFortran { void LLVMList::resize_if_needed(llvm::Value* list, llvm::Value* n, llvm::Value* capacity, int32_t type_size, - llvm::Type* el_type, llvm::Module& module) { + llvm::Type* el_type, llvm::Module* module) { llvm::Value *cond = builder->CreateICmpEQ(n, capacity); llvm::Function *fn = builder->GetInsertBlock()->getParent(); llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); @@ -2244,7 +2291,7 @@ namespace LFortran { new_capacity); llvm::Value* copy_data_ptr = get_pointer_to_list_data(list); llvm::Value* copy_data = LLVM::CreateLoad(*builder, copy_data_ptr); - copy_data = LLVM::lfortran_realloc(context, module, *builder, + copy_data = LLVM::lfortran_realloc(context, *module, *builder, copy_data, arg_size); copy_data = builder->CreateBitCast(copy_data, el_type->getPointerTo()); builder->CreateStore(copy_data, copy_data_ptr); @@ -2262,7 +2309,8 @@ namespace LFortran { } void LLVMList::append(llvm::Value* list, llvm::Value* item, - ASR::ttype_t* asr_type, llvm::Module& module) { + ASR::ttype_t* asr_type, llvm::Module* module, + std::map>& name2memidx) { llvm::Value* current_end_point = LLVM::CreateLoad(*builder, get_pointer_to_current_end_point(list)); llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_current_capacity(list)); std::string type_code = ASRUtils::get_type_code(asr_type); @@ -2270,13 +2318,14 @@ namespace LFortran { llvm::Type* el_type = std::get<2>(typecode2listtype[type_code]); resize_if_needed(list, current_end_point, current_capacity, type_size, el_type, module); - write_item(list, current_end_point, item, asr_type, false, module); + write_item(list, current_end_point, item, asr_type, false, module, name2memidx); shift_end_point_by_one(list); } void LLVMList::insert_item(llvm::Value* list, llvm::Value* pos, llvm::Value* item, ASR::ttype_t* asr_type, - llvm::Module& module) { + llvm::Module* module, + std::map>& name2memidx) { std::string type_code = ASRUtils::get_type_code(asr_type); llvm::Value* current_end_point = LLVM::CreateLoad(*builder, get_pointer_to_current_end_point(list)); @@ -2308,7 +2357,7 @@ namespace LFortran { // LLVMList should treat them as data members and create them // only if they are NULL llvm::AllocaInst *tmp_ptr = builder->CreateAlloca(el_type, nullptr); - LLVM::CreateStore(*builder, read_item(list, pos, false, module, false), tmp_ptr); + LLVM::CreateStore(*builder, read_item(list, pos, false, *module, false), tmp_ptr); llvm::Value* tmp = nullptr; // TODO: Should be created outside the user loop and not here. @@ -2337,8 +2386,8 @@ namespace LFortran { llvm::Value* next_index = builder->CreateAdd( LLVM::CreateLoad(*builder, pos_ptr), llvm::ConstantInt::get(context, llvm::APInt(32, 1))); - tmp = read_item(list, next_index, false, module, false); - write_item(list, next_index, LLVM::CreateLoad(*builder, tmp_ptr), false, module); + tmp = read_item(list, next_index, false, *module, false); + write_item(list, next_index, LLVM::CreateLoad(*builder, tmp_ptr), false, *module); LLVM::CreateStore(*builder, tmp, tmp_ptr); tmp = builder->CreateAdd( @@ -2351,7 +2400,7 @@ namespace LFortran { // end llvm_utils->start_new_block(loopend); - write_item(list, pos, item, asr_type, false, module); + write_item(list, pos, item, asr_type, false, module, name2memidx); shift_end_point_by_one(list); } @@ -2549,14 +2598,16 @@ namespace LFortran { } void LLVMTuple::tuple_deepcopy(llvm::Value* src, llvm::Value* dest, - ASR::Tuple_t* tuple_type, llvm::Module& module) { + ASR::Tuple_t* tuple_type, llvm::Module* module, + std::map>& name2memidx) { LFORTRAN_ASSERT(src->getType() == dest->getType()); for( size_t i = 0; i < tuple_type->n_type; i++ ) { llvm::Value* src_item = read_item(src, i, LLVM::is_llvm_struct( tuple_type->m_type[i])); llvm::Value* dest_item_ptr = read_item(dest, i, true); llvm_utils->deepcopy(src_item, dest_item_ptr, - tuple_type->m_type[i], module); + tuple_type->m_type[i], module, + name2memidx); } } diff --git a/src/libasr/codegen/llvm_utils.h b/src/libasr/codegen/llvm_utils.h index 11822c76df..2e89424971 100644 --- a/src/libasr/codegen/llvm_utils.h +++ b/src/libasr/codegen/llvm_utils.h @@ -12,6 +12,10 @@ namespace LFortran { + namespace LLVMArrUtils { + class Descriptor; + } + static inline void printf(llvm::LLVMContext &context, llvm::Module &module, llvm::IRBuilder<> &builder, const std::vector &args) { @@ -93,6 +97,7 @@ namespace LFortran { LLVMTuple* tuple_api; LLVMList* list_api; LLVMDictInterface* dict_api; + LLVMArrUtils::Descriptor* arr_api; LLVMUtils(llvm::LLVMContext& context, llvm::IRBuilder<>* _builder); @@ -120,7 +125,8 @@ namespace LFortran { void reset_iterators(); void deepcopy(llvm::Value* src, llvm::Value* dest, - ASR::ttype_t* asr_type, llvm::Module& module); + ASR::ttype_t* asr_type, llvm::Module* module, + std::map>& name2memidx); }; // LLVMUtils @@ -135,7 +141,7 @@ namespace LFortran { void resize_if_needed(llvm::Value* list, llvm::Value* n, llvm::Value* capacity, int32_t type_size, - llvm::Type* el_type, llvm::Module& module); + llvm::Type* el_type, llvm::Module* module); void shift_end_point_by_one(llvm::Value* list); @@ -162,12 +168,12 @@ namespace LFortran { llvm::Value* get_pointer_to_current_capacity(llvm::Value* list); void list_deepcopy(llvm::Value* src, llvm::Value* dest, - ASR::List_t* list_type, - llvm::Module& module); + ASR::List_t* list_type, llvm::Module* module, + std::map>& name2memidx); void list_deepcopy(llvm::Value* src, llvm::Value* dest, - ASR::ttype_t* element_type, - llvm::Module& module); + ASR::ttype_t* element_type, llvm::Module* module, + std::map>& name2memidx); llvm::Value* read_item(llvm::Value* list, llvm::Value* pos, bool enable_bounds_checking, @@ -179,19 +185,22 @@ namespace LFortran { llvm::Module& module); void write_item(llvm::Value* list, llvm::Value* pos, - llvm::Value* item, ASR::ttype_t* asr_type, - bool enable_bounds_checking, llvm::Module& module); + llvm::Value* item, ASR::ttype_t* asr_type, + bool enable_bounds_checking, llvm::Module* module, + std::map>& name2memidx); void write_item(llvm::Value* list, llvm::Value* pos, llvm::Value* item, bool enable_bounds_checking, llvm::Module& module); void append(llvm::Value* list, llvm::Value* item, - ASR::ttype_t* asr_type, llvm::Module& module); + ASR::ttype_t* asr_type, llvm::Module* module, + std::map>& name2memidx); void insert_item(llvm::Value* list, llvm::Value* pos, llvm::Value* item, ASR::ttype_t* asr_type, - llvm::Module& module); + llvm::Module* module, + std::map>& name2memidx); void remove(llvm::Value* list, llvm::Value* item, ASR::ttype_t* item_type, llvm::Module& module); @@ -232,7 +241,8 @@ namespace LFortran { bool get_pointer=false); void tuple_deepcopy(llvm::Value* src, llvm::Value* dest, - ASR::Tuple_t* type_code, llvm::Module& module); + ASR::Tuple_t* type_code, llvm::Module* module, + std::map>& name2memidx); llvm::Value* check_tuple_equality(llvm::Value* t1, llvm::Value* t2, ASR::Tuple_t* tuple_type, llvm::LLVMContext& context, @@ -298,7 +308,8 @@ namespace LFortran { void resolve_collision_for_write(llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Value* value, llvm::Module* module, ASR::ttype_t* key_asr_type, - ASR::ttype_t* value_asr_type) = 0; + ASR::ttype_t* value_asr_type, + std::map>& name2memidx) = 0; virtual llvm::Value* resolve_collision_for_read(llvm::Value* dict, llvm::Value* key_hash, @@ -307,19 +318,21 @@ namespace LFortran { virtual void rehash(llvm::Value* dict, llvm::Module* module, - ASR::ttype_t* key_asr_type, - ASR::ttype_t* value_asr_type) = 0; + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, + std::map>& name2memidx) = 0; virtual void rehash_all_at_once_if_needed(llvm::Value* dict, llvm::Module* module, ASR::ttype_t* key_asr_type, - ASR::ttype_t* value_asr_type) = 0; + ASR::ttype_t* value_asr_type, + std::map>& name2memidx) = 0; virtual void write_item(llvm::Value* dict, llvm::Value* key, llvm::Value* value, llvm::Module* module, - ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type) = 0; + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, + std::map>& name2memidx) = 0; virtual llvm::Value* read_item(llvm::Value* dict, llvm::Value* key, @@ -339,7 +352,8 @@ namespace LFortran { virtual void dict_deepcopy(llvm::Value* src, llvm::Value* dest, - ASR::Dict_t* dict_type, llvm::Module* module) = 0; + ASR::Dict_t* dict_type, llvm::Module* module, + std::map>& name2memidx) = 0; virtual llvm::Value* len(llvm::Value* dict) = 0; @@ -386,24 +400,27 @@ namespace LFortran { void resolve_collision_for_write(llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Value* value, llvm::Module* module, ASR::ttype_t* key_asr_type, - ASR::ttype_t* value_asr_type); + ASR::ttype_t* value_asr_type, + std::map>& name2memidx); llvm::Value* resolve_collision_for_read(llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type); void rehash(llvm::Value* dict, llvm::Module* module, - ASR::ttype_t* key_asr_type, - ASR::ttype_t* value_asr_type); + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, + std::map>& name2memidx); void rehash_all_at_once_if_needed(llvm::Value* dict, llvm::Module* module, ASR::ttype_t* key_asr_type, - ASR::ttype_t* value_asr_type); + ASR::ttype_t* value_asr_type, + std::map>& name2memidx); void write_item(llvm::Value* dict, llvm::Value* key, llvm::Value* value, llvm::Module* module, - ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type); + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, + std::map>& name2memidx); llvm::Value* read_item(llvm::Value* dict, llvm::Value* key, llvm::Module& module, ASR::Dict_t* key_asr_type, @@ -417,7 +434,8 @@ namespace LFortran { llvm::Value* get_pointer_to_keymask(llvm::Value* dict); void dict_deepcopy(llvm::Value* src, llvm::Value* dest, - ASR::Dict_t* dict_type, llvm::Module* module); + ASR::Dict_t* dict_type, llvm::Module* module, + std::map>& name2memidx); llvm::Value* len(llvm::Value* dict); @@ -440,7 +458,8 @@ namespace LFortran { void resolve_collision_for_write(llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Value* value, llvm::Module* module, ASR::ttype_t* key_asr_type, - ASR::ttype_t* value_asr_type); + ASR::ttype_t* value_asr_type, + std::map>& name2memidx); llvm::Value* resolve_collision_for_read(llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, @@ -464,11 +483,11 @@ namespace LFortran { void deepcopy_key_value_pair_linked_list(llvm::Value* srci, llvm::Value* desti, llvm::Value* dest_key_value_pairs, llvm::Value* src_capacity, ASR::Dict_t* dict_type, - llvm::Module* module); + llvm::Module* module, std::map>& name2memidx); void write_key_value_pair_linked_list(llvm::Value* kv_ll, llvm::Value* dict, llvm::Value* capacity, ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, - llvm::Module* module); + llvm::Module* module, std::map>& name2memidx); void resolve_collision(llvm::Value* capacity, llvm::Value* key_hash, llvm::Value* key, llvm::Value* key_value_pair_linked_list, @@ -507,24 +526,27 @@ namespace LFortran { void resolve_collision_for_write(llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Value* value, llvm::Module* module, ASR::ttype_t* key_asr_type, - ASR::ttype_t* value_asr_type); + ASR::ttype_t* value_asr_type, + std::map>& name2memidx); llvm::Value* resolve_collision_for_read(llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type); void rehash(llvm::Value* dict, llvm::Module* module, - ASR::ttype_t* key_asr_type, - ASR::ttype_t* value_asr_type); + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, + std::map>& name2memidx); void rehash_all_at_once_if_needed(llvm::Value* dict, llvm::Module* module, ASR::ttype_t* key_asr_type, - ASR::ttype_t* value_asr_type); + ASR::ttype_t* value_asr_type, + std::map>& name2memidx); void write_item(llvm::Value* dict, llvm::Value* key, llvm::Value* value, llvm::Module* module, - ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type); + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, + std::map>& name2memidx); llvm::Value* read_item(llvm::Value* dict, llvm::Value* key, llvm::Module& module, ASR::Dict_t* dict_type, @@ -537,7 +559,8 @@ namespace LFortran { llvm::Value* get_pointer_to_keymask(llvm::Value* dict); void dict_deepcopy(llvm::Value* src, llvm::Value* dest, - ASR::Dict_t* dict_type, llvm::Module* module); + ASR::Dict_t* dict_type, llvm::Module* module, + std::map>& name2memidx); llvm::Value* len(llvm::Value* dict); diff --git a/src/libasr/pass/pass_manager.h b/src/libasr/pass/pass_manager.h index df400ad196..65f442d709 100644 --- a/src/libasr/pass/pass_manager.h +++ b/src/libasr/pass/pass_manager.h @@ -39,6 +39,7 @@ #include #include #include +#include #include #include @@ -76,7 +77,8 @@ namespace LCompilers { {"loop_vectorise", &LFortran::pass_loop_vectorise}, {"array_dim_intrinsics_update", &LFortran::pass_update_array_dim_intrinsic_calls}, {"pass_list_expr", &LFortran::pass_list_expr}, - {"pass_array_by_data", &LFortran::pass_array_by_data} + {"pass_array_by_data", &LFortran::pass_array_by_data}, + {"subroutine_from_function", &LFortran::pass_create_subroutine_from_function} }; bool is_fast; @@ -100,6 +102,7 @@ namespace LCompilers { "pass_array_by_data", "pass_list_expr", "arr_slice", + "subroutine_from_function", "array_op", "print_arr", "print_list", @@ -117,6 +120,7 @@ namespace LCompilers { "implied_do_loops", "pass_array_by_data", "arr_slice", + "subroutine_from_function", "array_op", "print_arr", "print_list", diff --git a/src/libasr/pass/pass_utils.h b/src/libasr/pass/pass_utils.h index dc1950816f..23696e7efb 100644 --- a/src/libasr/pass/pass_utils.h +++ b/src/libasr/pass/pass_utils.h @@ -71,6 +71,10 @@ namespace LFortran { Vec replace_doloop(Allocator &al, const ASR::DoLoop_t &loop, int comp=-1); + static inline bool is_aggregate_type(ASR::expr_t* var) { + return ASR::is_a(*ASRUtils::expr_type(var)); + } + template class PassVisitor: public ASR::BaseWalkVisitor { diff --git a/src/libasr/pass/subroutine_from_function.cpp b/src/libasr/pass/subroutine_from_function.cpp new file mode 100644 index 0000000000..a2dc1b96c8 --- /dev/null +++ b/src/libasr/pass/subroutine_from_function.cpp @@ -0,0 +1,220 @@ +#include +#include +#include +#include +#include +#include +#include + +#include +#include + + +namespace LFortran { + +using ASR::down_cast; +using ASR::is_a; + +class CreateFunctionFromSubroutine: public PassUtils::PassVisitor { + + public: + + CreateFunctionFromSubroutine(Allocator &al_) : + PassVisitor(al_, nullptr) + { + pass_result.reserve(al, 1); + } + + ASR::symbol_t* create_subroutine_from_function(ASR::Function_t* s) { + for( auto& s_item: s->m_symtab->get_scope() ) { + ASR::symbol_t* curr_sym = s_item.second; + if( curr_sym->type == ASR::symbolType::Variable ) { + ASR::Variable_t* var = ASR::down_cast(curr_sym); + if( var->m_intent == ASR::intentType::Unspecified ) { + var->m_intent = ASR::intentType::In; + } else if( var->m_intent == ASR::intentType::ReturnVar ) { + var->m_intent = ASR::intentType::Out; + } + } + } + Vec a_args; + a_args.reserve(al, s->n_args + 1); + for( size_t i = 0; i < s->n_args; i++ ) { + a_args.push_back(al, s->m_args[i]); + } + LFORTRAN_ASSERT(s->m_return_var) + a_args.push_back(al, s->m_return_var); + ASR::asr_t* s_sub_asr = ASR::make_Function_t(al, s->base.base.loc, + s->m_symtab, + s->m_name, s->m_dependencies, s->n_dependencies, + a_args.p, a_args.size(), s->m_body, s->n_body, + nullptr, + s->m_abi, s->m_access, s->m_deftype, nullptr, false, false, + false, s->m_inline, s->m_static, + s->m_type_params, s->n_type_params, + s->m_restrictions, s->n_restrictions, + s->m_is_restriction); + ASR::symbol_t* s_sub = ASR::down_cast(s_sub_asr); + return s_sub; + } + + void visit_TranslationUnit(const ASR::TranslationUnit_t &x) { + std::vector> replace_vec; + // Transform functions returning arrays to subroutines + for (auto &item : x.m_global_scope->get_scope()) { + if (is_a(*item.second)) { + ASR::Function_t *s = down_cast(item.second); + if (s->m_return_var) { + /* + * A function which returns a aggregate type like array, struct will be converted + * to a subroutine with the destination array as the last + * argument. This helps in avoiding deep copies and the + * destination memory directly gets filled inside the subroutine. + */ + if( PassUtils::is_aggregate_type(s->m_return_var) ) { + ASR::symbol_t* s_sub = create_subroutine_from_function(s); + replace_vec.push_back(std::make_pair(item.first, s_sub)); + } + } + } + } + + // FIXME: this is a hack, we need to pass in a non-const `x`, + // which requires to generate a TransformVisitor. + ASR::TranslationUnit_t &xx = const_cast(x); + // Updating the symbol table so that now the name + // of the function (which returned array) now points + // to the newly created subroutine. + for( auto& item: replace_vec ) { + xx.m_global_scope->add_symbol(item.first, item.second); + } + + // Now visit everything else + for (auto &item : x.m_global_scope->get_scope()) { + this->visit_symbol(*item.second); + } + } + + void visit_Program(const ASR::Program_t &x) { + std::vector > replace_vec; + // FIXME: this is a hack, we need to pass in a non-const `x`, + // which requires to generate a TransformVisitor. + ASR::Program_t &xx = const_cast(x); + current_scope = xx.m_symtab; + + for (auto &item : x.m_symtab->get_scope()) { + if (is_a(*item.second)) { + ASR::Function_t *s = ASR::down_cast(item.second); + if (s->m_return_var) { + /* + * A function which returns an array will be converted + * to a subroutine with the destination array as the last + * argument. This helps in avoiding deep copies and the + * destination memory directly gets filled inside the subroutine. + */ + if( PassUtils::is_aggregate_type(s->m_return_var) ) { + ASR::symbol_t* s_sub = create_subroutine_from_function(s); + replace_vec.push_back(std::make_pair(item.first, s_sub)); + } + } + } + } + + // Updating the symbol table so that now the name + // of the function (which returned array) now points + // to the newly created subroutine. + for( auto& item: replace_vec ) { + current_scope->add_symbol(item.first, item.second); + } + + for (auto &item : x.m_symtab->get_scope()) { + if (is_a(*item.second)) { + ASR::AssociateBlock_t *s = ASR::down_cast(item.second); + visit_AssociateBlock(*s); + } + if (is_a(*item.second)) { + ASR::Function_t *s = ASR::down_cast(item.second); + visit_Function(*s); + } + } + + current_scope = xx.m_symtab; + transform_stmts(xx.m_body, xx.n_body); + + } + +}; + +class ReplaceFunctionCallWithSubroutineCall: public PassUtils::PassVisitor { + + private: + + ASR::expr_t *result_var; + + public: + + ReplaceFunctionCallWithSubroutineCall(Allocator& al_): + PassVisitor(al_, nullptr), result_var(nullptr) + { + pass_result.reserve(al, 1); + } + + void visit_Assignment(const ASR::Assignment_t& x) { + if( PassUtils::is_aggregate_type(x.m_target) ) { + result_var = x.m_target; + this->visit_expr(*(x.m_value)); + } + result_var = nullptr; + } + + void visit_FunctionCall(const ASR::FunctionCall_t& x) { + std::string x_name; + if( x.m_name->type == ASR::symbolType::ExternalSymbol ) { + x_name = down_cast(x.m_name)->m_name; + } else if( x.m_name->type == ASR::symbolType::Function ) { + x_name = down_cast(x.m_name)->m_name; + } + // The following checks if the name of a function actually + // points to a subroutine. If true this would mean that the + // original function returned an array and is now a subroutine. + // So the current function call will be converted to a subroutine + // call. In short, this check acts as a signal whether to convert + // a function call to a subroutine call. + if (current_scope == nullptr) { + return ; + } + + ASR::symbol_t *sub = current_scope->resolve_symbol(x_name); + if (sub && ASR::is_a(*sub) + && ASR::down_cast(sub)->m_return_var == nullptr) { + LFORTRAN_ASSERT(result_var != nullptr); + Vec s_args; + s_args.reserve(al, x.n_args + 1); + for( size_t i = 0; i < x.n_args; i++ ) { + s_args.push_back(al, x.m_args[i]); + } + ASR::call_arg_t result_arg; + result_arg.loc = result_var->base.loc; + result_arg.m_value = result_var; + s_args.push_back(al, result_arg); + ASR::stmt_t* subrout_call = LFortran::ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, + sub, nullptr, + s_args.p, s_args.size(), nullptr)); + pass_result.push_back(al, subrout_call); + } + result_var = nullptr; + } + +}; + +void pass_create_subroutine_from_function(Allocator &al, ASR::TranslationUnit_t &unit, + const LCompilers::PassOptions& /*pass_options*/) { + CreateFunctionFromSubroutine v(al); + v.visit_TranslationUnit(unit); + ReplaceFunctionCallWithSubroutineCall u(al); + u.visit_TranslationUnit(unit); + LFORTRAN_ASSERT(asr_verify(unit)); +} + + +} // namespace LFortran diff --git a/src/libasr/pass/subroutine_from_function.h b/src/libasr/pass/subroutine_from_function.h new file mode 100644 index 0000000000..9dc0afa709 --- /dev/null +++ b/src/libasr/pass/subroutine_from_function.h @@ -0,0 +1,14 @@ +#ifndef LIBASR_PASS_SUBROUTINE_FROM_FUNCTION_H +#define LIBASR_PASS_SUBROUTINE_FROM_FUNCTION_H + +#include +#include + +namespace LFortran { + + void pass_create_subroutine_from_function(Allocator &al, ASR::TranslationUnit_t &unit, + const LCompilers::PassOptions& pass_options); + +} // namespace LFortran + +#endif // LIBASR_PASS_SUBROUTINE_FROM_FUNCTION_H