diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index e61654d7d1..8fa1141988 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -151,6 +151,9 @@ RUN(NAME test_list_01 LABELS cpython llvm) RUN(NAME test_list_02 LABELS cpython llvm) RUN(NAME test_list_03 LABELS cpython llvm) RUN(NAME test_list_04 LABELS cpython llvm) +RUN(NAME test_list_05 LABELS cpython llvm) +RUN(NAME test_list_06 LABELS cpython llvm) +RUN(NAME test_list_07 LABELS cpython llvm) RUN(NAME test_tuple_01 LABELS cpython llvm) RUN(NAME test_tuple_02 LABELS cpython llvm) RUN(NAME modules_01 LABELS cpython llvm) diff --git a/integration_tests/test_list_05.py b/integration_tests/test_list_05.py new file mode 100644 index 0000000000..1667b22ca3 --- /dev/null +++ b/integration_tests/test_list_05.py @@ -0,0 +1,56 @@ +from ltypes import i32, f64 + +def check_list_of_tuples(l: list[tuple[i32, f64, str]], sign: i32): + size: i32 = len(l) + i: i32 + t: tuple[i32, f64, str] + string: str + + for i in range(size): + t = l[i] + string = str(sign * i) + "_str" + assert t[0] == sign * i + assert l[i][0] == sign * i + + assert t[1] == float(sign * i) + assert l[i][1] == float(sign * i) + + assert t[2] == string + assert l[i][2] == string + +def test_list_of_tuples(): + l1: list[tuple[i32, f64, str]] = [] + t: tuple[i32, f64, str] + size: i32 = 20 + i: i32 + string: str + + for i in range(size): + t = (i, float(i), str(i) + "_str") + l1.append(t) + + check_list_of_tuples(l1, 1) + + for i in range(size//2): + l1.remove(l1[len(l1) - 1]) + + t = l1[len(l1) - 1] + assert t[0] == size//2 - 1 + assert t[1] == size//2 - 1 + assert t[2] == str(size//2 - 1) + "_str" + + for i in range(size//2, size): + string = str(i) + "_str" + t = (i, float(i), string) + l1.insert(i, t) + + check_list_of_tuples(l1, 1) + + for i in range(size): + string = str(-i) + "_str" + t = (-i, float(-i), string) + l1[i] = t + + check_list_of_tuples(l1, -1) + +test_list_of_tuples() diff --git a/integration_tests/test_list_06.py b/integration_tests/test_list_06.py new file mode 100644 index 0000000000..9081430c89 --- /dev/null +++ b/integration_tests/test_list_06.py @@ -0,0 +1,64 @@ +from ltypes import i32, f64 +from copy import deepcopy + +def check_mat_and_vec(mat: list[list[f64]], vec: list[f64]): + rows: i32 = len(mat) + cols: i32 = len(vec) + i: i32 + j: i32 + + for i in range(rows): + for j in range(cols): + assert mat[i][j] == float(i + j) + + for i in range(cols): + assert vec[i] == 2 * float(i) + +def test_list_of_lists(): + tensors: list[list[list[list[f64]]]] = [] + tensor: list[list[list[f64]]] = [] + mat: list[list[f64]] = [] + vec: list[f64] = [] + rows: i32 = 10 + cols: i32 = 5 + i: i32 + j: i32 + k: i32 + l: i32 + + for i in range(rows): + for j in range(cols): + vec.append(float(i + j)) + mat.append(deepcopy(vec)) + vec.clear() + + for i in range(cols): + vec.append(2 * float(i)) + + check_mat_and_vec(mat, vec) + + for k in range(rows): + tensor.append(deepcopy(mat)) + for i in range(rows): + for j in range(cols): + mat[i][j] += float(1) + + for k in range(rows): + for i in range(rows): + for j in range(cols): + assert mat[i][j] - tensor[k][i][j] == rows - k + + for l in range(2 * rows): + tensors.append(deepcopy(tensor)) + for i in range(rows): + for j in range(rows): + for k in range(cols): + tensor[i][j][k] += float(1) + + for l in range(2 * rows): + for i in range(rows): + for j in range(rows): + for k in range(cols): + assert tensor[i][j][k] - tensors[l][i][j][k] == 2 * rows - l + +test_list_of_lists() diff --git a/integration_tests/test_list_07.py b/integration_tests/test_list_07.py new file mode 100644 index 0000000000..157353d03c --- /dev/null +++ b/integration_tests/test_list_07.py @@ -0,0 +1,67 @@ +from ltypes import c64, i32 +from copy import deepcopy + +def test_tuple_with_lists(): + mat: list[list[c64]] = [] + vec: list[c64] = [] + tensor: tuple[list[list[c64]], list[c64]] + tensors: list[tuple[list[list[c64]], list[c64]]] = [] + i: i32 + j: i32 + k: i32 + l: i32 + rows: i32 = 10 + cols: i32 = 5 + + for i in range(rows): + for j in range(cols): + vec.append(complex(i + j, 0)) + mat.append(deepcopy(vec)) + vec.clear() + + for i in range(cols): + vec.append(complex(2 * i, 0)) + + for i in range(rows): + for j in range(cols): + assert mat[i][j] - vec[j] == i - j + + tensor = (deepcopy(mat), deepcopy(vec)) + + for i in range(rows): + for j in range(cols): + mat[i][j] += complex(0, 3.0) + + for i in range(cols): + vec[i] += complex(0, 2.0) + + for i in range(rows): + for j in range(cols): + assert tensor[0][i][j] - mat[i][j] == -complex(0, 3.0) + + for i in range(cols): + assert tensor[1][i] - vec[i] == -complex(0, 2.0) + + tensor = (deepcopy(mat), deepcopy(vec)) + + for k in range(2 * rows): + tensors.append(deepcopy(tensor)) + for i in range(rows): + for j in range(cols): + mat[i][j] += complex(1.0, 2.0) + + for i in range(cols): + vec[i] += complex(1.0, 2.0) + + tensor = (deepcopy(mat), deepcopy(vec)) + + for k in range(2 * rows): + for i in range(rows): + for j in range(cols): + assert tensors[k][0][i][j] - mat[i][j] == -(2 * rows - k) * complex(1.0, 2.0) + + for k in range(2 * rows): + for i in range(cols): + assert tensors[k][1][i] - vec[i] == -(2 * rows - k) * complex(1.0, 2.0) + +test_tuple_with_lists() diff --git a/src/libasr/ASR.asdl b/src/libasr/ASR.asdl index 643d012d4e..4e96ee0471 100644 --- a/src/libasr/ASR.asdl +++ b/src/libasr/ASR.asdl @@ -206,6 +206,7 @@ stmt | SetRemove(expr a, expr ele) | ListInsert(expr a, expr pos, expr ele) | ListRemove(expr a, expr ele) + | ListClear(expr a) | DictInsert(expr a, expr key, expr value) diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index 19942f59a9..f3fb73441c 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -244,6 +244,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor ptr_loads(2), is_assignment_target(false) { + llvm_utils->tuple_api = tuple_api.get(); + llvm_utils->list_api = list_api.get(); } llvm::Value* CreateLoad(llvm::Value *x) { @@ -1147,7 +1149,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor this->visit_expr(*x.m_args[i]); llvm::Value* item = tmp; llvm::Value* pos = llvm::ConstantInt::get(context, llvm::APInt(32, i)); - list_api->write_item(const_list, pos, item); + list_api->write_item(const_list, pos, item, list_type->m_type, *module); } ptr_loads = ptr_loads_copy; tmp = const_list; @@ -1203,31 +1205,34 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } void visit_ListAppend(const ASR::ListAppend_t& x) { + ASR::List_t* asr_list = ASR::down_cast(ASRUtils::expr_type(x.m_a)); uint64_t ptr_loads_copy = ptr_loads; ptr_loads = 0; this->visit_expr(*x.m_a); - ptr_loads = ptr_loads_copy; llvm::Value* plist = tmp; + ptr_loads = !LLVM::is_llvm_struct(asr_list->m_type); this->visit_expr_wrapper(x.m_ele, true); llvm::Value *item = tmp; + ptr_loads = ptr_loads_copy; - ASR::List_t* asr_list = ASR::down_cast(ASRUtils::expr_type(x.m_a)); - std::string type_code = ASRUtils::get_type_code(asr_list->m_type); - list_api->append(plist, item, *module, type_code); + list_api->append(plist, item, asr_list->m_type, *module); } void visit_ListItem(const ASR::ListItem_t& x) { + ASR::ttype_t* el_type = ASRUtils::get_contained_type( + ASRUtils::expr_type(x.m_a)); uint64_t ptr_loads_copy = ptr_loads; ptr_loads = 0; this->visit_expr(*x.m_a); - ptr_loads = ptr_loads_copy; llvm::Value* plist = tmp; + ptr_loads = 1; this->visit_expr_wrapper(x.m_pos, true); + ptr_loads = ptr_loads_copy; llvm::Value *pos = tmp; - tmp = list_api->read_item(plist, pos); + tmp = list_api->read_item(plist, pos, LLVM::is_llvm_struct(el_type)); } void visit_ListLen(const ASR::ListLen_t& x) { @@ -1244,23 +1249,23 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } void visit_ListInsert(const ASR::ListInsert_t& x) { + ASR::List_t* asr_list = ASR::down_cast( + ASRUtils::expr_type(x.m_a)); uint64_t ptr_loads_copy = ptr_loads; ptr_loads = 0; this->visit_expr(*x.m_a); - ptr_loads = ptr_loads_copy; llvm::Value* plist = tmp; + ptr_loads = 1; this->visit_expr_wrapper(x.m_pos, true); llvm::Value *pos = tmp; + ptr_loads = !LLVM::is_llvm_struct(asr_list->m_type); this->visit_expr_wrapper(x.m_ele, true); llvm::Value *item = tmp; + ptr_loads = ptr_loads_copy; - ASR::List_t* asr_list = ASR::down_cast( - ASRUtils::expr_type(x.m_a)); - std::string type_code = ASRUtils::get_type_code(asr_list->m_type); - - list_api->insert_item(plist, pos, item, *module, type_code); + list_api->insert_item(plist, pos, item, asr_list->m_type, *module); } void visit_ListRemove(const ASR::ListRemove_t& x) { @@ -1268,12 +1273,23 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor uint64_t ptr_loads_copy = ptr_loads; ptr_loads = 0; this->visit_expr(*x.m_a); - ptr_loads = ptr_loads_copy; llvm::Value* plist = tmp; + ptr_loads = !LLVM::is_llvm_struct(asr_el_type); this->visit_expr_wrapper(x.m_ele, true); + ptr_loads = ptr_loads_copy; llvm::Value *item = tmp; - list_api->remove(plist, item, asr_el_type->type, *module); + list_api->remove(plist, item, asr_el_type, *module); + } + + void visit_ListClear(const ASR::ListClear_t& x) { + uint64_t ptr_loads_copy = ptr_loads; + ptr_loads = 0; + this->visit_expr(*x.m_a); + llvm::Value* plist = tmp; + ptr_loads = ptr_loads_copy; + + list_api->list_clear(plist); } void visit_TupleLen(const ASR::TupleLen_t& x) { @@ -1291,7 +1307,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor this->visit_expr_wrapper(x.m_pos, true); llvm::Value *pos = tmp; - tmp = tuple_api->read_item(ptuple, pos); + tmp = tuple_api->read_item(ptuple, pos, LLVM::is_llvm_struct(x.m_type)); } void visit_ArrayItem(const ASR::ArrayItem_t& x) { @@ -1838,7 +1854,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor a_kind); std::string el_type_code = ASRUtils::get_type_code(asr_list->m_type); int32_t type_size = -1; - if( ASR::is_a(*asr_list->m_type) ) { + if( LLVM::is_llvm_struct(asr_list->m_type) || + ASR::is_a(*asr_list->m_type) || + ASR::is_a(*asr_list->m_type) ) { llvm::DataLayout data_layout(module.get()); type_size = data_layout.getTypeAllocSize(el_llvm_type); } else { @@ -2218,8 +2236,17 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor is_malloc_array_type, is_list, m_dims, n_dims, a_kind); + int32_t type_size = -1; + if( LLVM::is_llvm_struct(asr_list->m_type) || + ASR::is_a(*asr_list->m_type) || + ASR::is_a(*asr_list->m_type) ) { + llvm::DataLayout data_layout(module.get()); + type_size = data_layout.getTypeAllocSize(el_llvm_type); + } else { + type_size = a_kind; + } std::string el_type_code = ASRUtils::get_type_code(asr_list->m_type); - type = list_api->get_list_type(el_llvm_type, el_type_code, a_kind)->getPointerTo(); + type = list_api->get_list_type(el_llvm_type, el_type_code, type_size)->getPointerTo(); break; } default : @@ -2943,7 +2970,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor ASR::List_t* value_asr_list = ASR::down_cast( ASRUtils::expr_type(x.m_value)); std::string value_type_code = ASRUtils::get_type_code(value_asr_list->m_type); - list_api->list_deepcopy(value_list, target_list, value_type_code, *module); + list_api->list_deepcopy(value_list, target_list, + value_asr_list, *module); return ; } else if( is_target_tuple && is_value_tuple ) { uint64_t ptr_loads_copy = ptr_loads; @@ -2981,7 +3009,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor ASR::Tuple_t* value_tuple_type = ASR::down_cast(asr_value_type); std::string type_code = ASRUtils::get_type_code(value_tuple_type->m_type, value_tuple_type->n_type); - tuple_api->tuple_deepcopy(value_tuple, target_tuple, type_code); + tuple_api->tuple_deepcopy(value_tuple, target_tuple, + value_tuple_type, *module); } } return ; diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index 62e3671236..23484baa0d 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -1,5 +1,6 @@ #include #include +#include namespace LFortran { @@ -179,6 +180,59 @@ namespace LFortran { return builder->CreateCall(fn, args); } + llvm::Value* LLVMUtils::is_equal_by_value(llvm::Value* left, llvm::Value* right, + llvm::Module& module, ASR::ttype_t* asr_type) { + switch( asr_type->type ) { + case ASR::ttypeType::Integer: { + return builder->CreateICmpEQ(left, right); + }; + case ASR::ttypeType::Real: { + return builder->CreateFCmpOEQ(left, right); + } + case ASR::ttypeType::Character: { + return lfortran_str_cmp(left, right, "_lpython_str_compare_eq", + module); + } + case ASR::ttypeType::Tuple: { + ASR::Tuple_t* tuple_type = ASR::down_cast(asr_type); + return tuple_api->check_tuple_equality(left, right, tuple_type, context, + builder, module); + } + default: { + throw LCompilersException("LLVMUtils::is_equal_by_value isn't implemented for " + + ASRUtils::type_to_str_python(asr_type)); + } + } + } + + void LLVMUtils::deepcopy(llvm::Value* src, llvm::Value* dest, + ASR::ttype_t* asr_type, llvm::Module& module) { + switch( asr_type->type ) { + case ASR::ttypeType::Integer: + case ASR::ttypeType::Real: + case ASR::ttypeType::Character: + case ASR::ttypeType::Logical: + case ASR::ttypeType::Complex: { + LLVM::CreateStore(*builder, src, dest); + break ; + }; + case ASR::ttypeType::Tuple: { + ASR::Tuple_t* tuple_type = ASR::down_cast(asr_type); + tuple_api->tuple_deepcopy(src, dest, tuple_type, module); + break ; + } + case ASR::ttypeType::List: { + ASR::List_t* list_type = ASR::down_cast(asr_type); + list_api->list_deepcopy(src, dest, list_type, module); + break ; + } + default: { + throw LCompilersException("LLVMUtils::deepcopy isn't implemented for " + + ASRUtils::type_to_str_python(asr_type)); + } + } + } + LLVMList::LLVMList(llvm::LLVMContext& context_, LLVMUtils* llvm_utils_, llvm::IRBuilder<>* builder_): @@ -233,9 +287,9 @@ namespace LFortran { } void LLVMList::list_deepcopy(llvm::Value* src, llvm::Value* dest, - std::string& src_type_code, - llvm::Module& module) { + ASR::List_t* list_type, llvm::Module& module) { LFORTRAN_ASSERT(src->getType() == dest->getType()); + std::string src_type_code = ASRUtils::get_type_code(list_type->m_type); llvm::Value* src_end_point = LLVM::CreateLoad(*builder, get_pointer_to_current_end_point(src)); llvm::Value* src_capacity = LLVM::CreateLoad(*builder, get_pointer_to_current_capacity(src)); llvm::Value* dest_end_point_ptr = get_pointer_to_current_end_point(dest); @@ -250,15 +304,73 @@ namespace LFortran { arg_size); llvm::Type* el_type = std::get<2>(typecode2listtype[src_type_code]); copy_data = builder->CreateBitCast(copy_data, el_type->getPointerTo()); - builder->CreateMemCpy(copy_data, llvm::MaybeAlign(), src_data, - llvm::MaybeAlign(), arg_size); - builder->CreateStore(copy_data, get_pointer_to_list_data(dest)); + + // We consider the case when the element type of a list is defined by a struct + // which may also contain non-trivial structs (such as in case of list[list[f64]], + // list[tuple[f64]]). We need to make sure that all the data inside those structs + // is deepcopied and not just the address of the first element of those structs. + // Hence we dive deeper into the lowest level of nested types and deepcopy everything + // properly. If we don't consider this case then the data only from first level of nested types + // will be deep copied and rest will be shallow copied. The importance of this case + // can be figured out by goind through, integration_tests/test_list_06.py and + // integration_tests/test_list_07.py. + if( LLVM::is_llvm_struct(list_type->m_type) ) { + builder->CreateStore(copy_data, get_pointer_to_list_data(dest)); + llvm::AllocaInst *pos_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), + nullptr); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 0)), pos_ptr); + + llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); + llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); + llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); + + // head + llvm_utils->start_new_block(loophead); + { + llvm::Value *cond = builder->CreateICmpSGT( + src_end_point, + LLVM::CreateLoad(*builder, pos_ptr)); + builder->CreateCondBr(cond, loopbody, loopend); + } + + // body + llvm_utils->start_new_block(loopbody); + { + llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); + llvm::Value* srci = read_item(src, pos, true); + llvm::Value* desti = read_item(dest, pos, true); + llvm_utils->deepcopy(srci, desti, list_type->m_type, module); + llvm::Value* tmp = builder->CreateAdd( + pos, + llvm::ConstantInt::get(context, llvm::APInt(32, 1))); + LLVM::CreateStore(*builder, tmp, pos_ptr); + } + + builder->CreateBr(loophead); + + // end + llvm_utils->start_new_block(loopend); + } else { + builder->CreateMemCpy(copy_data, llvm::MaybeAlign(), src_data, + llvm::MaybeAlign(), arg_size); + builder->CreateStore(copy_data, get_pointer_to_list_data(dest)); + } + } + + void LLVMList::write_item(llvm::Value* list, llvm::Value* pos, + llvm::Value* item, ASR::ttype_t* asr_type, + llvm::Module& module) { + llvm::Value* list_data = LLVM::CreateLoad(*builder, get_pointer_to_list_data(list)); + llvm::Value* element_ptr = llvm_utils->create_ptr_gep(list_data, pos); + llvm_utils->deepcopy(item, element_ptr, asr_type, module); } - void LLVMList::write_item(llvm::Value* list, llvm::Value* pos, llvm::Value* item) { + void LLVMList::write_item(llvm::Value* list, llvm::Value* pos, + llvm::Value* item) { llvm::Value* list_data = LLVM::CreateLoad(*builder, get_pointer_to_list_data(list)); llvm::Value* element_ptr = llvm_utils->create_ptr_gep(list_data, pos); - builder->CreateStore(item, element_ptr); + LLVM::CreateStore(*builder, item, element_ptr); } llvm::Value* LLVMList::read_item(llvm::Value* list, llvm::Value* pos, bool get_pointer) { @@ -311,21 +423,22 @@ namespace LFortran { } void LLVMList::append(llvm::Value* list, llvm::Value* item, - llvm::Module& module, - std::string& type_code) { + ASR::ttype_t* asr_type, llvm::Module& module) { llvm::Value* current_end_point = LLVM::CreateLoad(*builder, get_pointer_to_current_end_point(list)); llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_current_capacity(list)); + std::string type_code = ASRUtils::get_type_code(asr_type); int type_size = std::get<1>(typecode2listtype[type_code]); llvm::Type* el_type = std::get<2>(typecode2listtype[type_code]); resize_if_needed(list, current_end_point, current_capacity, type_size, el_type, module); - write_item(list, current_end_point, item); + write_item(list, current_end_point, item, asr_type, module); shift_end_point_by_one(list); } void LLVMList::insert_item(llvm::Value* list, llvm::Value* pos, - llvm::Value* item, llvm::Module& module, - std::string& type_code) { + llvm::Value* item, ASR::ttype_t* asr_type, + llvm::Module& module) { + std::string type_code = ASRUtils::get_type_code(asr_type); llvm::Value* current_end_point = LLVM::CreateLoad(*builder, get_pointer_to_current_end_point(list)); llvm::Value* current_capacity = LLVM::CreateLoad(*builder, @@ -380,7 +493,7 @@ namespace LFortran { LLVM::CreateLoad(*builder, pos_ptr), llvm::ConstantInt::get(context, llvm::APInt(32, 1))); tmp = read_item(list, next_index, false); - write_item(list, next_index, LLVM::CreateLoad(*builder, tmp_ptr)); + write_item(list, next_index, LLVM::CreateLoad(*builder, tmp_ptr)); LLVM::CreateStore(*builder, tmp, tmp_ptr); tmp = builder->CreateAdd( @@ -393,12 +506,12 @@ namespace LFortran { // end llvm_utils->start_new_block(loopend); - write_item(list, pos, item); + write_item(list, pos, item, asr_type, module); shift_end_point_by_one(list); } llvm::Value* LLVMList::find_item_position(llvm::Value* list, - llvm::Value* item, ASR::ttypeType item_type, llvm::Module& module) { + llvm::Value* item, ASR::ttype_t* item_type, llvm::Module& module) { llvm::Type* pos_type = llvm::Type::getInt32Ty(context); llvm::Value* current_end_point = LLVM::CreateLoad(*builder, get_pointer_to_current_end_point(list)); @@ -425,15 +538,13 @@ namespace LFortran { // head llvm_utils->start_new_block(loophead); { - llvm::Value* is_item_not_equal = nullptr; - llvm::Value* left_arg = read_item(list, LLVM::CreateLoad(*builder, i), false); - if( item_type == ASR::ttypeType::Character ) { - is_item_not_equal = llvm_utils->lfortran_str_cmp(left_arg, item, - "_lpython_str_compare_noteq", - module); - } else { - is_item_not_equal = builder->CreateICmpNE(left_arg, item); - } + llvm::Value* left_arg = read_item(list, LLVM::CreateLoad(*builder, i), + LLVM::is_llvm_struct(item_type)); + llvm::Value* is_item_not_equal = builder->CreateNot( + llvm_utils->is_equal_by_value( + left_arg, item, + module, item_type) + ); llvm::Value *cond = builder->CreateAnd(is_item_not_equal, builder->CreateICmpSGT(current_end_point, LLVM::CreateLoad(*builder, i))); @@ -484,7 +595,7 @@ namespace LFortran { } void LLVMList::remove(llvm::Value* list, llvm::Value* item, - ASR::ttypeType item_type, llvm::Module& module) { + ASR::ttype_t* item_type, llvm::Module& module) { llvm::Type* pos_type = llvm::Type::getInt32Ty(context); llvm::Value* current_end_point = LLVM::CreateLoad(*builder, get_pointer_to_current_end_point(list)); @@ -536,6 +647,13 @@ namespace LFortran { builder->CreateStore(end_point, end_point_ptr); } + void LLVMList::list_clear(llvm::Value* list) { + llvm::Value* end_point_ptr = get_pointer_to_current_end_point(list); + llvm::Value* zero = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 0)); + LLVM::CreateStore(*builder, zero, end_point_ptr); + } + LLVMTuple::LLVMTuple(llvm::LLVMContext& context_, LLVMUtils* llvm_utils_, @@ -577,14 +695,31 @@ namespace LFortran { } void LLVMTuple::tuple_deepcopy(llvm::Value* src, llvm::Value* dest, - std::string& type_code) { + ASR::Tuple_t* tuple_type, llvm::Module& module) { LFORTRAN_ASSERT(src->getType() == dest->getType()); - size_t n_elements = typecode2tupletype[type_code].second; - for( size_t i = 0; i < n_elements; i++ ) { - llvm::Value* src_item = read_item(src, i, false); + for( size_t i = 0; i < tuple_type->n_type; i++ ) { + llvm::Value* src_item = read_item(src, i, LLVM::is_llvm_struct( + tuple_type->m_type[i])); llvm::Value* dest_item_ptr = read_item(dest, i, true); - builder->CreateStore(src_item, dest_item_ptr); + llvm_utils->deepcopy(src_item, dest_item_ptr, + tuple_type->m_type[i], module); + } + } + + llvm::Value* LLVMTuple::check_tuple_equality(llvm::Value* t1, llvm::Value* t2, + ASR::Tuple_t* tuple_type, + llvm::LLVMContext& context, + llvm::IRBuilder<>* builder, + llvm::Module& module) { + llvm::Value* is_equal = llvm::ConstantInt::get(context, llvm::APInt(1, 1)); + for( size_t i = 0; i < tuple_type->n_type; i++ ) { + llvm::Value* t1i = llvm_utils->tuple_api->read_item(t1, i); + llvm::Value* t2i = llvm_utils->tuple_api->read_item(t2, i); + llvm::Value* is_t1_eq_t2 = llvm_utils->is_equal_by_value(t1i, t2i, module, + tuple_type->m_type[i]); + is_equal = builder->CreateAnd(is_equal, is_t1_eq_t2); } + return is_equal; } } // namespace LFortran diff --git a/src/libasr/codegen/llvm_utils.h b/src/libasr/codegen/llvm_utils.h index 324271e7d5..3f2fa432bb 100644 --- a/src/libasr/codegen/llvm_utils.h +++ b/src/libasr/codegen/llvm_utils.h @@ -49,8 +49,17 @@ namespace LFortran { llvm::IRBuilder<> &builder, llvm::Value* arg_size); llvm::Value* lfortran_realloc(llvm::LLVMContext &context, llvm::Module &module, llvm::IRBuilder<> &builder, llvm::Value* ptr, llvm::Value* arg_size); + static inline bool is_llvm_struct(ASR::ttype_t* asr_type) { + return ASR::is_a(*asr_type) || + ASR::is_a(*asr_type) || + ASR::is_a(*asr_type) || + ASR::is_a(*asr_type); + } } + class LLVMList; + class LLVMTuple; + class LLVMUtils { private: @@ -60,6 +69,9 @@ namespace LFortran { public: + LLVMTuple* tuple_api; + LLVMList* list_api; + LLVMUtils(llvm::LLVMContext& context, llvm::IRBuilder<>* _builder); @@ -78,6 +90,12 @@ namespace LFortran { llvm::Value* lfortran_str_cmp(llvm::Value* left_arg, llvm::Value* right_arg, std::string runtime_func_name, llvm::Module& module); + llvm::Value* is_equal_by_value(llvm::Value* left, llvm::Value* right, + llvm::Module& module, ASR::ttype_t* asr_type); + + void deepcopy(llvm::Value* src, llvm::Value* dest, + ASR::ttype_t* asr_type, llvm::Module& module); + }; // LLVMUtils class LLVMList { @@ -98,7 +116,7 @@ namespace LFortran { public: LLVMList(llvm::LLVMContext& context_, LLVMUtils* llvm_utils, - llvm::IRBuilder<>* builder); + llvm::IRBuilder<>* builder); llvm::Type* get_list_type(llvm::Type* el_type, std::string& type_code, int32_t type_size); @@ -114,7 +132,7 @@ namespace LFortran { llvm::Value* get_pointer_to_current_capacity(llvm::Value* list); void list_deepcopy(llvm::Value* src, llvm::Value* dest, - std::string& src_type_code, + ASR::List_t* list_type, llvm::Module& module); llvm::Value* read_item(llvm::Value* list, llvm::Value* pos, @@ -122,21 +140,27 @@ namespace LFortran { llvm::Value* len(llvm::Value* list); + void write_item(llvm::Value* list, llvm::Value* pos, + llvm::Value* item, ASR::ttype_t* asr_type, + llvm::Module& module); + void write_item(llvm::Value* list, llvm::Value* pos, llvm::Value* item); void append(llvm::Value* list, llvm::Value* item, - llvm::Module& module, std::string& type_code); + ASR::ttype_t* asr_type, llvm::Module& module); void insert_item(llvm::Value* list, llvm::Value* pos, - llvm::Value* item, llvm::Module& module, - std::string& type_code); + llvm::Value* item, ASR::ttype_t* asr_type, + llvm::Module& module); void remove(llvm::Value* list, llvm::Value* item, - ASR::ttypeType item_type, llvm::Module& module); + ASR::ttype_t* item_type, llvm::Module& module); + + void list_clear(llvm::Value* list); llvm::Value* find_item_position(llvm::Value* list, - llvm::Value* item, ASR::ttypeType item_type, + llvm::Value* item, ASR::ttype_t* item_type, llvm::Module& module); }; @@ -167,7 +191,11 @@ namespace LFortran { bool get_pointer=false); void tuple_deepcopy(llvm::Value* src, llvm::Value* dest, - std::string& type_code); + ASR::Tuple_t* type_code, llvm::Module& module); + + llvm::Value* check_tuple_equality(llvm::Value* t1, llvm::Value* t2, + ASR::Tuple_t* tuple_type, llvm::LLVMContext& context, + llvm::IRBuilder<>* builder, llvm::Module& module); }; } // LFortran diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 9e55bae10b..0310f483ac 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -197,6 +197,9 @@ ASR::Module_t* load_module(Allocator &al, SymbolTable *symtab, std::vector &rl_path, bool <ypes, const std::function err) { + if( module_name == "copy" ) { + return nullptr; + } ltypes = false; LFORTRAN_ASSERT(symtab); if (symtab->get_scope().find(module_name) != symtab->get_scope().end()) { @@ -4422,6 +4425,14 @@ class BodyVisitor : public CommonVisitor { ASRUtils::type_to_str(type) + " type.", x.base.base.loc); } return; + } else if( call_name == "deepcopy" ) { + if( args.size() != 1 ) { + throw SemanticError("deepcopy only accepts one argument, found " + + std::to_string(args.size()) + " instead.", + x.base.base.loc); + } + tmp = (ASR::asr_t*) args[0].m_value; + return ; } else { // The function was not found and it is not intrinsic throw SemanticError("Function '" + call_name + "' is not declared and not intrinsic", diff --git a/src/lpython/semantics/python_attribute_eval.h b/src/lpython/semantics/python_attribute_eval.h index 53ef5b7870..6d0e7715da 100644 --- a/src/lpython/semantics/python_attribute_eval.h +++ b/src/lpython/semantics/python_attribute_eval.h @@ -21,6 +21,7 @@ struct AttributeHandler { {"int@bit_length", &eval_int_bit_length}, {"list@append", &eval_list_append}, {"list@remove", &eval_list_remove}, + {"list@clear", &eval_list_clear}, {"list@insert", &eval_list_insert}, {"list@pop", &eval_list_pop}, {"set@pop", &eval_set_pop}, @@ -187,6 +188,23 @@ struct AttributeHandler { return make_ListPop_t(al, loc, s, idx, list_type, nullptr); } + static ASR::asr_t* eval_list_clear(ASR::expr_t *s, Allocator &al, + const Location &loc, Vec &args, diag::Diagnostics & diag) { + if (args.size() != 0) { + diag.add(diag::Diagnostic( + "Incorrect number of arguments in 'clear', it accepts no argument", + diag::Level::Error, diag::Stage::Semantic, { + diag::Label("incorrect number of arguments in clear (found: " + + std::to_string(args.size()) + ", expected: 0)", + {loc}) + }) + ); + throw SemanticAbort(); + } + + return make_ListClear_t(al, loc, s); + } + static ASR::asr_t* eval_set_pop(ASR::expr_t *s, Allocator &al, const Location &loc, Vec &args, diag::Diagnostics &/*diag*/) { if (args.size() != 0) {