From bfbea2a3e9ea8f1622ac5166cfb6ed381db884e7 Mon Sep 17 00:00:00 2001 From: Gagandeep Singh Date: Wed, 17 Aug 2022 20:26:13 +0530 Subject: [PATCH 1/6] Added integration test for dict data structure --- integration_tests/test_dict_01.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 integration_tests/test_dict_01.py diff --git a/integration_tests/test_dict_01.py b/integration_tests/test_dict_01.py new file mode 100644 index 0000000000..27301eee4e --- /dev/null +++ b/integration_tests/test_dict_01.py @@ -0,0 +1,16 @@ +from ltypes import i32, f64 + +def test_dict(): + rollnumber2cpi: dict[i32, f64] = {0: 1.1} + i: i32 + size: i32 = 1000 + + for i in range(1000, 1000 + size): + rollnumber2cpi[i] = float(i/100.0 + 5.0) + + for i in range(1000 + size - 1, 1001, -1): + assert abs(rollnumber2cpi[i] - i/100.0 - 5.0) <= 1e-12 + + assert abs(rollnumber2cpi[0] - 1.1) <= 1e-12 + +test_dict() From 8171bc0ab134b874aa1130556bcc583eab3a4eca Mon Sep 17 00:00:00 2001 From: Gagandeep Singh Date: Wed, 17 Aug 2022 20:27:17 +0530 Subject: [PATCH 2/6] Added _lfortran_calloc intrinsic in C runtime library --- src/libasr/runtime/lfortran_intrinsics.c | 4 ++++ src/libasr/runtime/lfortran_intrinsics.h | 2 ++ 2 files changed, 6 insertions(+) diff --git a/src/libasr/runtime/lfortran_intrinsics.c b/src/libasr/runtime/lfortran_intrinsics.c index ac18806f25..38b5dadbe8 100644 --- a/src/libasr/runtime/lfortran_intrinsics.c +++ b/src/libasr/runtime/lfortran_intrinsics.c @@ -860,6 +860,10 @@ LFORTRAN_API int8_t* _lfortran_realloc(int8_t* ptr, int32_t size) { return (int8_t*) realloc(ptr, size); } +LFORTRAN_API int8_t* _lfortran_calloc(int32_t count, int32_t size) { + return (int8_t*) calloc(count, size); +} + LFORTRAN_API void _lfortran_free(char* ptr) { free((void*)ptr); } diff --git a/src/libasr/runtime/lfortran_intrinsics.h b/src/libasr/runtime/lfortran_intrinsics.h index baccb85bff..e1785a2b85 100644 --- a/src/libasr/runtime/lfortran_intrinsics.h +++ b/src/libasr/runtime/lfortran_intrinsics.h @@ -160,6 +160,8 @@ LFORTRAN_API int _lfortran_str_ord(char** s); LFORTRAN_API char* _lfortran_str_chr(int c); LFORTRAN_API int _lfortran_str_to_int(char** s); LFORTRAN_API char* _lfortran_malloc(int size); +LFORTRAN_API int8_t* _lfortran_realloc(int8_t* ptr, int32_t size); +LFORTRAN_API int8_t* _lfortran_calloc(int32_t count, int32_t size); LFORTRAN_API void _lfortran_free(char* ptr); LFORTRAN_API void _lfortran_string_init(int size_plus_one, char *s); LFORTRAN_API int32_t _lfortran_iand32(int32_t x, int32_t y); From f987f90dacbf18a17421eb277c5bba9a3ade0326 Mon Sep 17 00:00:00 2001 From: Gagandeep Singh Date: Wed, 17 Aug 2022 20:27:44 +0530 Subject: [PATCH 3/6] Added interface for LLVMDict and included free_data in LLVMList --- src/libasr/codegen/llvm_utils.h | 104 +++++++++++++++++++++++++++++++- 1 file changed, 101 insertions(+), 3 deletions(-) diff --git a/src/libasr/codegen/llvm_utils.h b/src/libasr/codegen/llvm_utils.h index 3f2fa432bb..1ec2c47cfd 100644 --- a/src/libasr/codegen/llvm_utils.h +++ b/src/libasr/codegen/llvm_utils.h @@ -49,6 +49,10 @@ namespace LFortran { llvm::IRBuilder<> &builder, llvm::Value* arg_size); llvm::Value* lfortran_realloc(llvm::LLVMContext &context, llvm::Module &module, llvm::IRBuilder<> &builder, llvm::Value* ptr, llvm::Value* arg_size); + llvm::Value* lfortran_calloc(llvm::LLVMContext &context, llvm::Module &module, + llvm::IRBuilder<> &builder, llvm::Value* count, llvm::Value* type_size); + llvm::Value* lfortran_free(llvm::LLVMContext &context, llvm::Module &module, + llvm::IRBuilder<> &builder, llvm::Value* ptr); static inline bool is_llvm_struct(ASR::ttype_t* asr_type) { return ASR::is_a(*asr_type) || ASR::is_a(*asr_type) || @@ -59,6 +63,7 @@ namespace LFortran { class LLVMList; class LLVMTuple; + class LLVMDict; class LLVMUtils { @@ -71,6 +76,7 @@ namespace LFortran { LLVMTuple* tuple_api; LLVMList* list_api; + LLVMDict* dict_api; LLVMUtils(llvm::LLVMContext& context, llvm::IRBuilder<>* _builder); @@ -121,6 +127,10 @@ namespace LFortran { llvm::Type* get_list_type(llvm::Type* el_type, std::string& type_code, int32_t type_size); + void list_init(std::string& type_code, llvm::Value* list, + llvm::Module& module, llvm::Value* initial_capacity, + llvm::Value* n); + void list_init(std::string& type_code, llvm::Value* list, llvm::Module& module, int32_t initial_capacity=1, int32_t n=0); @@ -135,17 +145,23 @@ namespace LFortran { ASR::List_t* list_type, llvm::Module& module); + void list_deepcopy(llvm::Value* src, llvm::Value* dest, + ASR::ttype_t* element_type, + llvm::Module& module); + llvm::Value* read_item(llvm::Value* list, llvm::Value* pos, - bool get_pointer=false); + bool get_pointer=false, bool check_index_bound=true); llvm::Value* len(llvm::Value* list); + void check_index_within_bounds(llvm::Value* list, llvm::Value* pos); + void write_item(llvm::Value* list, llvm::Value* pos, llvm::Value* item, ASR::ttype_t* asr_type, - llvm::Module& module); + llvm::Module& module, bool check_index_bound=true); void write_item(llvm::Value* list, llvm::Value* pos, - llvm::Value* item); + llvm::Value* item, bool check_index_bound=true); void append(llvm::Value* list, llvm::Value* item, ASR::ttype_t* asr_type, llvm::Module& module); @@ -162,6 +178,8 @@ namespace LFortran { llvm::Value* find_item_position(llvm::Value* list, llvm::Value* item, ASR::ttype_t* item_type, llvm::Module& module); + + void free_data(llvm::Value* list, llvm::Module& module); }; class LLVMTuple { @@ -198,6 +216,86 @@ namespace LFortran { llvm::IRBuilder<>* builder, llvm::Module& module); }; + class LLVMDict { + private: + + llvm::LLVMContext& context; + LLVMUtils* llvm_utils; + llvm::IRBuilder<>* builder; + llvm::AllocaInst *pos_ptr, *is_key_matching_var; + bool are_iterators_set; + + std::map, + std::tuple, + std::pair>> typecode2dicttype; + + public: + + LLVMDict(llvm::LLVMContext& context_, + LLVMUtils* llvm_utils, + llvm::IRBuilder<>* builder); + + llvm::Type* get_dict_type(std::string key_type_code, std::string value_type_code, + int32_t key_type_size, int32_t value_type_size, + llvm::Type* key_type, llvm::Type* value_type); + + void dict_init(std::string key_type_code, std::string value_type_code, + llvm::Value* dict, llvm::Module* module, size_t initial_capacity); + + llvm::Value* get_key_list(llvm::Value* dict); + + llvm::Value* get_value_list(llvm::Value* dict); + + llvm::Value* get_pointer_to_occupancy(llvm::Value* dict); + + llvm::Value* get_pointer_to_capacity(llvm::Value* dict); + + llvm::Value* get_key_hash(llvm::Value* capacity, llvm::Value* key, + ASR::ttype_t* key_asr_type, llvm::Module& module); + + void linear_probing(llvm::Value* capacity, llvm::Value* key_hash, + llvm::Value* key, llvm::Value* key_list, + llvm::Value* key_mask, llvm::Module& module, + ASR::ttype_t* key_asr_type); + + void linear_probing_for_write(llvm::Value* dict, llvm::Value* key_hash, + llvm::Value* key, llvm::Value* value, + llvm::Module& module, ASR::ttype_t* key_asr_type, + ASR::ttype_t* value_asr_type); + + llvm::Value* linear_probing_for_read(llvm::Value* dict, llvm::Value* key_hash, + llvm::Value* key, llvm::Module& module, + ASR::ttype_t* key_asr_type); + + void rehash(llvm::Value* dict, llvm::Module* module, + ASR::ttype_t* key_asr_type, + ASR::ttype_t* value_asr_type); + + void rehash_all_at_once_if_needed(llvm::Value* dict, + llvm::Module* module, + ASR::ttype_t* key_asr_type, + ASR::ttype_t* value_asr_type); + + void write_item(llvm::Value* dict, llvm::Value* key, + llvm::Value* value, llvm::Module* module, + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type); + + llvm::Value* read_item(llvm::Value* dict, llvm::Value* key, + llvm::Module& module, ASR::ttype_t* key_asr_type, + bool get_pointer=false); + + llvm::Value* get_pointer_to_keymask(llvm::Value* dict); + + void set_iterators(); + + void reset_iterators(); + + void dict_deepcopy(llvm::Value* src, llvm::Value* dest, + ASR::Dict_t* dict_type, llvm::Module* module); + + llvm::Value* len(llvm::Value* dict); + }; + } // LFortran #endif // LFORTRAN_LLVM_UTILS_H From 3bb54d3fa4c1f5c603c0e021c16419082354d868 Mon Sep 17 00:00:00 2001 From: Gagandeep Singh Date: Wed, 17 Aug 2022 20:28:23 +0530 Subject: [PATCH 4/6] Implemented LLVMDict interface using list_api --- src/libasr/codegen/llvm_utils.cpp | 518 +++++++++++++++++++++++++++++- 1 file changed, 511 insertions(+), 7 deletions(-) diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index 23484baa0d..9b62f64151 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -49,6 +49,23 @@ namespace LFortran { return builder.CreateCall(fn, args); } + llvm::Value* lfortran_calloc(llvm::LLVMContext &context, llvm::Module &module, + llvm::IRBuilder<> &builder, llvm::Value* count, llvm::Value* type_size) { + std::string func_name = "_lfortran_calloc"; + llvm::Function *fn = module.getFunction(func_name); + if (!fn) { + llvm::FunctionType *function_type = llvm::FunctionType::get( + llvm::Type::getInt8PtrTy(context), { + llvm::Type::getInt32Ty(context), + llvm::Type::getInt32Ty(context) + }, true); + fn = llvm::Function::Create(function_type, + llvm::Function::ExternalLinkage, func_name, module); + } + std::vector args = {count, type_size}; + return builder.CreateCall(fn, args); + } + llvm::Value* lfortran_realloc(llvm::LLVMContext &context, llvm::Module &module, llvm::IRBuilder<> &builder, llvm::Value* ptr, llvm::Value* arg_size) { std::string func_name = "_lfortran_realloc"; @@ -68,6 +85,24 @@ namespace LFortran { }; return builder.CreateCall(fn, args); } + + llvm::Value* lfortran_free(llvm::LLVMContext &context, llvm::Module &module, + llvm::IRBuilder<> &builder, llvm::Value* ptr) { + std::string func_name = "_lfortran_free"; + llvm::Function *fn = module.getFunction(func_name); + if (!fn) { + llvm::FunctionType *function_type = llvm::FunctionType::get( + llvm::Type::getVoidTy(context), { + llvm::Type::getInt8PtrTy(context) + }, true); + fn = llvm::Function::Create(function_type, + llvm::Function::ExternalLinkage, func_name, module); + } + std::vector args = { + builder.CreateBitCast(ptr, llvm::Type::getInt8PtrTy(context)), + }; + return builder.CreateCall(fn, args); + } } // namespace LLVM LLVMUtils::LLVMUtils(llvm::LLVMContext& context, @@ -240,6 +275,16 @@ namespace LFortran { llvm_utils(std::move(llvm_utils_)), builder(std::move(builder_)) {} + LLVMDict::LLVMDict(llvm::LLVMContext& context_, + LLVMUtils* llvm_utils_, + llvm::IRBuilder<>* builder_): + context(context_), + llvm_utils(std::move(llvm_utils_)), + builder(std::move(builder_)), + pos_ptr(nullptr), is_key_matching_var(nullptr), + are_iterators_set(false) { + } + llvm::Type* LLVMList::get_list_type(llvm::Type* el_type, std::string& type_code, int32_t type_size) { if( typecode2listtype.find(type_code) != typecode2listtype.end() ) { @@ -253,6 +298,28 @@ namespace LFortran { return list_desc; } + llvm::Type* LLVMDict::get_dict_type(std::string key_type_code, std::string value_type_code, + int32_t key_type_size, int32_t value_type_size, + llvm::Type* key_type, llvm::Type* value_type) { + std::pair llvm_key = std::make_pair(key_type_code, value_type_code); + if( typecode2dicttype.find(llvm_key) != typecode2dicttype.end() ) { + return std::get<0>(typecode2dicttype[llvm_key]); + } + + llvm::Type* key_list_type = llvm_utils->list_api->get_list_type(key_type, + key_type_code, key_type_size); + llvm::Type* value_list_type = llvm_utils->list_api->get_list_type(value_type, + value_type_code, value_type_size); + std::vector dict_type_vec = {llvm::Type::getInt32Ty(context), + key_list_type, value_list_type, + llvm::Type::getInt1PtrTy(context)}; + llvm::Type* dict_desc = llvm::StructType::create(context, dict_type_vec, "dict"); + typecode2dicttype[llvm_key] = std::make_tuple(dict_desc, + std::make_pair(key_type_size, value_type_size), + std::make_pair(key_type, value_type)); + return dict_desc; + } + llvm::Value* LLVMList::get_pointer_to_list_data(llvm::Value* list) { return llvm_utils->create_gep(list, 2); } @@ -268,7 +335,7 @@ namespace LFortran { void LLVMList::list_init(std::string& type_code, llvm::Value* list, llvm::Module& module, int32_t initial_capacity, int32_t n) { if( typecode2listtype.find(type_code) == typecode2listtype.end() ) { - LCompilersException("list for " + type_code + " not declared yet."); + throw LCompilersException("list for " + type_code + " not declared yet."); } int32_t type_size = std::get<1>(typecode2listtype[type_code]); llvm::Value* arg_size = llvm::ConstantInt::get(context, @@ -286,10 +353,74 @@ namespace LFortran { builder->CreateStore(current_capacity, get_pointer_to_current_capacity(list)); } + void LLVMList::list_init(std::string& type_code, llvm::Value* list, + llvm::Module& module, llvm::Value* initial_capacity, + llvm::Value* n) { + if( typecode2listtype.find(type_code) == typecode2listtype.end() ) { + throw LCompilersException("list for " + type_code + " not declared yet."); + } + int32_t type_size = std::get<1>(typecode2listtype[type_code]); + llvm::Value* llvm_type_size = llvm::ConstantInt::get(context, llvm::APInt(32, type_size)); + llvm::Value* arg_size = builder->CreateMul(llvm_type_size, initial_capacity); + llvm::Value* list_data = LLVM::lfortran_malloc(context, module, *builder, arg_size); + + llvm::Type* el_type = std::get<2>(typecode2listtype[type_code]); + list_data = builder->CreateBitCast(list_data, el_type->getPointerTo()); + llvm::Value* list_data_ptr = get_pointer_to_list_data(list); + builder->CreateStore(list_data, list_data_ptr); + builder->CreateStore(n, get_pointer_to_current_end_point(list)); + builder->CreateStore(initial_capacity, get_pointer_to_current_capacity(list)); + } + + llvm::Value* LLVMDict::get_key_list(llvm::Value* dict) { + return llvm_utils->create_gep(dict, 1); + } + + llvm::Value* LLVMDict::get_value_list(llvm::Value* dict) { + return llvm_utils->create_gep(dict, 2); + } + + llvm::Value* LLVMDict::get_pointer_to_occupancy(llvm::Value* dict) { + return llvm_utils->create_gep(dict, 0); + } + + llvm::Value* LLVMDict::get_pointer_to_capacity(llvm::Value* dict) { + return llvm_utils->list_api->get_pointer_to_current_capacity( + get_value_list(dict)); + } + + void LLVMDict::dict_init(std::string key_type_code, std::string value_type_code, + llvm::Value* dict, llvm::Module* module, size_t initial_capacity) { + llvm::Value* n_ptr = get_pointer_to_occupancy(dict); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 0)), n_ptr); + llvm::Value* key_list = get_key_list(dict); + llvm::Value* value_list = get_value_list(dict); + llvm_utils->list_api->list_init(key_type_code, key_list, *module, + initial_capacity, initial_capacity); + llvm_utils->list_api->list_init(value_type_code, value_list, *module, + initial_capacity, initial_capacity); + llvm::DataLayout data_layout(module); + size_t bool_size = data_layout.getTypeAllocSize(llvm::Type::getInt1Ty(context)); + llvm::Value* llvm_capacity = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, initial_capacity)); + llvm::Value* llvm_bool_size = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, bool_size)); + llvm::Value* key_mask = LLVM::lfortran_calloc(context, *module, *builder, llvm_capacity, + llvm_bool_size); + key_mask = builder->CreateBitCast(key_mask, llvm::Type::getInt1PtrTy(context)); + LLVM::CreateStore(*builder, key_mask, get_pointer_to_keymask(dict)); + } + void LLVMList::list_deepcopy(llvm::Value* src, llvm::Value* dest, ASR::List_t* list_type, llvm::Module& module) { + list_deepcopy(src, dest, list_type->m_type, module); + } + + void LLVMList::list_deepcopy(llvm::Value* src, llvm::Value* dest, + ASR::ttype_t* element_type, llvm::Module& module) { LFORTRAN_ASSERT(src->getType() == dest->getType()); - std::string src_type_code = ASRUtils::get_type_code(list_type->m_type); + std::string src_type_code = ASRUtils::get_type_code(element_type); llvm::Value* src_end_point = LLVM::CreateLoad(*builder, get_pointer_to_current_end_point(src)); llvm::Value* src_capacity = LLVM::CreateLoad(*builder, get_pointer_to_current_capacity(src)); llvm::Value* dest_end_point_ptr = get_pointer_to_current_end_point(dest); @@ -314,8 +445,11 @@ namespace LFortran { // will be deep copied and rest will be shallow copied. The importance of this case // can be figured out by goind through, integration_tests/test_list_06.py and // integration_tests/test_list_07.py. - if( LLVM::is_llvm_struct(list_type->m_type) ) { + if( LLVM::is_llvm_struct(element_type) ) { builder->CreateStore(copy_data, get_pointer_to_list_data(dest)); + // TODO: Should be created outside the user loop and not here. + // LLVMList should treat them as data members and create them + // only if they are NULL llvm::AllocaInst *pos_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), @@ -340,7 +474,7 @@ namespace LFortran { llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); llvm::Value* srci = read_item(src, pos, true); llvm::Value* desti = read_item(dest, pos, true); - llvm_utils->deepcopy(srci, desti, list_type->m_type, module); + llvm_utils->deepcopy(srci, desti, element_type, module); llvm::Value* tmp = builder->CreateAdd( pos, llvm::ConstantInt::get(context, llvm::APInt(32, 1))); @@ -358,22 +492,371 @@ namespace LFortran { } } + void LLVMDict::dict_deepcopy(llvm::Value* src, llvm::Value* dest, + ASR::Dict_t* dict_type, llvm::Module* module) { + LFORTRAN_ASSERT(src->getType() == dest->getType()); + llvm::Value* src_occupancy = LLVM::CreateLoad(*builder, get_pointer_to_occupancy(src)); + llvm::Value* dest_occupancy_ptr = get_pointer_to_occupancy(dest); + LLVM::CreateStore(*builder, src_occupancy, dest_occupancy_ptr); + + llvm::Value* src_key_list = get_key_list(src); + llvm::Value* dest_key_list = get_key_list(dest); + llvm_utils->list_api->list_deepcopy(src_key_list, dest_key_list, + dict_type->m_key_type, *module); + + llvm::Value* src_value_list = get_value_list(src); + llvm::Value* dest_value_list = get_value_list(dest); + llvm_utils->list_api->list_deepcopy(src_value_list, dest_value_list, + dict_type->m_value_type, *module); + + llvm::Value* src_key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(src)); + llvm::Value* dest_key_mask_ptr = get_pointer_to_keymask(dest); + llvm::DataLayout data_layout(module); + size_t bool_size = data_layout.getTypeAllocSize(llvm::Type::getInt1Ty(context)); + llvm::Value* llvm_bool_size = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, bool_size)); + llvm::Value* src_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(src)); + llvm::Value* dest_key_mask = LLVM::lfortran_calloc(context, *module, *builder, src_capacity, + llvm_bool_size); + dest_key_mask = builder->CreateBitCast(dest_key_mask, llvm::Type::getInt1PtrTy(context)); + builder->CreateMemCpy(dest_key_mask, llvm::MaybeAlign(), src_key_mask, + llvm::MaybeAlign(), builder->CreateMul(src_capacity, llvm_bool_size)); + LLVM::CreateStore(*builder, dest_key_mask, dest_key_mask_ptr); + } + + void LLVMList::check_index_within_bounds(llvm::Value* /*list*/, llvm::Value* /*pos*/) { + + } + void LLVMList::write_item(llvm::Value* list, llvm::Value* pos, llvm::Value* item, ASR::ttype_t* asr_type, - llvm::Module& module) { + llvm::Module& module, bool check_index_bound) { + if( check_index_bound ) { + check_index_within_bounds(list, pos); + } llvm::Value* list_data = LLVM::CreateLoad(*builder, get_pointer_to_list_data(list)); llvm::Value* element_ptr = llvm_utils->create_ptr_gep(list_data, pos); llvm_utils->deepcopy(item, element_ptr, asr_type, module); } void LLVMList::write_item(llvm::Value* list, llvm::Value* pos, - llvm::Value* item) { + llvm::Value* item, bool check_index_bound) { + if( check_index_bound ) { + check_index_within_bounds(list, pos); + } llvm::Value* list_data = LLVM::CreateLoad(*builder, get_pointer_to_list_data(list)); llvm::Value* element_ptr = llvm_utils->create_ptr_gep(list_data, pos); LLVM::CreateStore(*builder, item, element_ptr); } - llvm::Value* LLVMList::read_item(llvm::Value* list, llvm::Value* pos, bool get_pointer) { + llvm::Value* LLVMDict::get_pointer_to_keymask(llvm::Value* dict) { + return llvm_utils->create_gep(dict, 3); + } + + void LLVMDict::set_iterators() { + if( are_iterators_set ) { + return ; + } + pos_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + is_key_matching_var = builder->CreateAlloca(llvm::Type::getInt1Ty(context), nullptr); + are_iterators_set = true; + } + + void LLVMDict::reset_iterators() { + pos_ptr = nullptr; + is_key_matching_var = nullptr; + are_iterators_set = false; + } + + void LLVMDict::linear_probing(llvm::Value* capacity, llvm::Value* key_hash, + llvm::Value* key, llvm::Value* key_list, + llvm::Value* key_mask, llvm::Module& module, + ASR::ttype_t* key_asr_type) { + set_iterators(); + LLVM::CreateStore(*builder, key_hash, pos_ptr); + + llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); + llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); + llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); + + // head + llvm_utils->start_new_block(loophead); + { + llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); + llvm::Value* is_key_set = LLVM::CreateLoad(*builder, + llvm_utils->create_ptr_gep(key_mask, pos)); + llvm::Value* is_key_matching = llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), + llvm::APInt(1, 0)); + LLVM::CreateStore(*builder, is_key_matching, is_key_matching_var); + llvm::Function *fn = builder->GetInsertBlock()->getParent(); + llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); + llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); + llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); + builder->CreateCondBr(is_key_set, thenBB, elseBB); + builder->SetInsertPoint(thenBB); + { + llvm::Value* original_key = llvm_utils->list_api->read_item(key_list, pos, + LLVM::is_llvm_struct(key_asr_type), false); + is_key_matching = llvm_utils->is_equal_by_value(key, original_key, module, + key_asr_type); + LLVM::CreateStore(*builder, is_key_matching, is_key_matching_var); + } + builder->CreateBr(mergeBB); + + llvm_utils->start_new_block(elseBB); + llvm_utils->start_new_block(mergeBB); + // TODO: Allow safe exit if pos becomes key_hash again. + // Ideally should not happen as dict will be resized once + // load factor touches a threshold (which will always be less than 1) + // so there will be some key which will not be set. However for safety + // we can add an exit from the loop with a error message. + llvm::Value *cond = builder->CreateAnd(is_key_set, builder->CreateNot(LLVM::CreateLoad(*builder, is_key_matching_var))); + builder->CreateCondBr(cond, loopbody, loopend); + } + + // body + llvm_utils->start_new_block(loopbody); + { + llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); + pos = builder->CreateAdd(pos, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 1))); + pos = builder->CreateSRem(pos, capacity); + LLVM::CreateStore(*builder, pos, pos_ptr); + } + + builder->CreateBr(loophead); + + // end + llvm_utils->start_new_block(loopend); + } + + void LLVMDict::linear_probing_for_write(llvm::Value* dict, llvm::Value* key_hash, + llvm::Value* key, llvm::Value* value, + llvm::Module& module, ASR::ttype_t* key_asr_type, + ASR::ttype_t* value_asr_type) { + llvm::Value* key_list = get_key_list(dict); + llvm::Value* value_list = get_value_list(dict); + llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); + llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); + linear_probing(capacity, key_hash, key, key_list, key_mask, module, key_asr_type); + llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); + llvm_utils->list_api->write_item(key_list, pos, key, + key_asr_type, module, false); + llvm_utils->list_api->write_item(value_list, pos, value, + value_asr_type, module, false); + llvm::Value* is_slot_empty = builder->CreateNot(LLVM::CreateLoad(*builder, + llvm_utils->create_ptr_gep(key_mask, pos))); + llvm::Value* occupancy_ptr = get_pointer_to_occupancy(dict); + is_slot_empty = builder->CreateZExt(is_slot_empty, llvm::Type::getInt32Ty(context)); + llvm::Value* occupancy = LLVM::CreateLoad(*builder, occupancy_ptr); + LLVM::CreateStore(*builder, builder->CreateAdd(occupancy, is_slot_empty), + occupancy_ptr); + LLVM::CreateStore(*builder, + llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), llvm::APInt(1, 1)), + llvm_utils->create_ptr_gep(key_mask, pos)); + reset_iterators(); + } + + llvm::Value* LLVMDict::linear_probing_for_read(llvm::Value* dict, llvm::Value* key_hash, + llvm::Value* key, llvm::Module& module, + ASR::ttype_t* key_asr_type) { + llvm::Value* key_list = get_key_list(dict); + llvm::Value* value_list = get_value_list(dict); + llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); + llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); + linear_probing(capacity, key_hash, key, key_list, key_mask, module, key_asr_type); + llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); + llvm::Value* item = llvm_utils->list_api->read_item(value_list, pos, true, false); + reset_iterators(); + return item; + } + + llvm::Value* LLVMDict::get_key_hash(llvm::Value* capacity, llvm::Value* key, + ASR::ttype_t* key_asr_type, llvm::Module& /*module*/) { + // Write specialised hash functions for intrinsic types + // This is to avoid unnecessary calls to C-runtime and do + // as much as possible in LLVM directly. + switch( key_asr_type->type ) { + case ASR::ttypeType::Integer: { + // Simple modulo with the capacity of the dict. + // We can update it later to do a better hash function + // which produces lesser collisions. + return builder->CreateSRem(key, capacity); + } + default: { + throw LCompilersException("Hashing " + ASRUtils::type_to_str_python(key_asr_type) + + " isn't implemented yet."); + } + } + } + + void LLVMDict::rehash(llvm::Value* dict, llvm::Module* module, + ASR::ttype_t* key_asr_type, + ASR::ttype_t* value_asr_type) { + llvm::Value* capacity_ptr = get_pointer_to_capacity(dict); + llvm::Value* old_capacity = LLVM::CreateLoad(*builder, capacity_ptr); + llvm::Value* capacity = builder->CreateMul(old_capacity, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 2))); + capacity = builder->CreateAdd(capacity, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 1))); + LLVM::CreateStore(*builder, capacity, capacity_ptr); + + std::string key_type_code = ASRUtils::get_type_code(key_asr_type); + std::string value_type_code = ASRUtils::get_type_code(value_asr_type); + std::pair dict_type_key = std::make_pair(key_type_code, value_type_code); + llvm::Type* key_llvm_type = std::get<2>(typecode2dicttype[dict_type_key]).first; + llvm::Type* value_llvm_type = std::get<2>(typecode2dicttype[dict_type_key]).second; + int32_t key_type_size = std::get<1>(typecode2dicttype[dict_type_key]).first; + int32_t value_type_size = std::get<1>(typecode2dicttype[dict_type_key]).second; + + llvm::Value* key_list = get_key_list(dict); + llvm::Value* new_key_list = builder->CreateAlloca(llvm_utils->list_api->get_list_type(key_llvm_type, + key_type_code, key_type_size), nullptr); + llvm_utils->list_api->list_init(key_type_code, new_key_list, *module, capacity, capacity); + + llvm::Value* value_list = get_value_list(dict); + llvm::Value* new_value_list = builder->CreateAlloca(llvm_utils->list_api->get_list_type(value_llvm_type, + value_type_code, value_type_size), nullptr); + llvm_utils->list_api->list_init(value_type_code, new_value_list, *module, capacity, capacity); + + llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); + llvm::DataLayout data_layout(module); + size_t bool_size = data_layout.getTypeAllocSize(llvm::Type::getInt1Ty(context)); + llvm::Value* llvm_bool_size = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, bool_size)); + llvm::Value* new_key_mask = LLVM::lfortran_calloc(context, *module, *builder, capacity, + llvm_bool_size); + new_key_mask = builder->CreateBitCast(new_key_mask, llvm::Type::getInt1PtrTy(context)); + + llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); + // TODO: Should be created outside the user loop and not here. + // LLVMDict should treat them as data members and create them + // only if they are NULL + llvm::Value* idx_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 0)), idx_ptr); + + llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); + llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); + llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); + + // head + llvm_utils->start_new_block(loophead); + { + llvm::Value *cond = builder->CreateICmpSGT(old_capacity, LLVM::CreateLoad(*builder, idx_ptr)); + builder->CreateCondBr(cond, loopbody, loopend); + } + + // body + llvm_utils->start_new_block(loopbody); + { + llvm::Value* idx = LLVM::CreateLoad(*builder, idx_ptr); + llvm::Function *fn = builder->GetInsertBlock()->getParent(); + llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); + llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); + llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); + llvm::Value* is_key_set = LLVM::CreateLoad(*builder, llvm_utils->create_ptr_gep(key_mask, idx)); + builder->CreateCondBr(is_key_set, thenBB, elseBB); + builder->SetInsertPoint(thenBB); + { + llvm::Value* key = llvm_utils->list_api->read_item(key_list, idx, + LLVM::is_llvm_struct(key_asr_type), false); + llvm::Value* value = llvm_utils->list_api->read_item(value_list, idx, + LLVM::is_llvm_struct(value_asr_type), false); + llvm::Value* key_hash = get_key_hash(current_capacity, key, key_asr_type, *module); + linear_probing(current_capacity, key_hash, key, new_key_list, + new_key_mask, *module, key_asr_type); + llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); + llvm::Value* key_dest = llvm_utils->list_api->read_item(new_key_list, pos, + true, false); + llvm_utils->deepcopy(key, key_dest, key_asr_type, *module); + llvm::Value* value_dest = llvm_utils->list_api->read_item(new_value_list, pos, + true, false); + llvm_utils->deepcopy(value, value_dest, value_asr_type, *module); + llvm::Value* key_mask_dest = llvm_utils->create_ptr_gep(new_key_mask, pos); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), + llvm::APInt(1, 1)), key_mask_dest); + } + builder->CreateBr(mergeBB); + + llvm_utils->start_new_block(elseBB); + llvm_utils->start_new_block(mergeBB); + idx = builder->CreateAdd(idx, llvm::ConstantInt::get( + llvm::Type::getInt32Ty(context), llvm::APInt(32, 1))); + LLVM::CreateStore(*builder, idx, idx_ptr); + } + + builder->CreateBr(loophead); + + // end + llvm_utils->start_new_block(loopend); + reset_iterators(); + + // TODO: Free key_list, value_list and key_mask + llvm_utils->list_api->free_data(key_list, *module); + llvm_utils->list_api->free_data(value_list, *module); + LLVM::lfortran_free(context, *module, *builder, key_mask); + LLVM::CreateStore(*builder, LLVM::CreateLoad(*builder, new_key_list), key_list); + LLVM::CreateStore(*builder, LLVM::CreateLoad(*builder, new_value_list), value_list); + LLVM::CreateStore(*builder, new_key_mask, get_pointer_to_keymask(dict)); + } + + void LLVMDict::rehash_all_at_once_if_needed(llvm::Value* dict, llvm::Module* module, + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type) { + llvm::Function *fn = builder->GetInsertBlock()->getParent(); + llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); + llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); + llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); + + llvm::Value* occupancy = LLVM::CreateLoad(*builder, get_pointer_to_occupancy(dict)); + llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); + occupancy = builder->CreateAdd(occupancy, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 1))); + occupancy = builder->CreateSIToFP(occupancy, llvm::Type::getFloatTy(context)); + capacity = builder->CreateSIToFP(capacity, llvm::Type::getFloatTy(context)); + llvm::Value* load_factor = builder->CreateFDiv(occupancy, capacity); + // Threshold hash is chosen from https://en.wikipedia.org/wiki/Hash_table#Load_factor + llvm::Value* load_factor_threshold = llvm::ConstantFP::get(llvm::Type::getFloatTy(context), + llvm::APFloat((float) 0.6)); + builder->CreateCondBr(builder->CreateFCmpOGE(load_factor, load_factor_threshold), thenBB, elseBB); + builder->SetInsertPoint(thenBB); + { + rehash(dict, module, key_asr_type, value_asr_type); + } + builder->CreateBr(mergeBB); + + llvm_utils->start_new_block(elseBB); + llvm_utils->start_new_block(mergeBB); + } + + void LLVMDict::write_item(llvm::Value* dict, llvm::Value* key, + llvm::Value* value, llvm::Module* module, + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type) { + rehash_all_at_once_if_needed(dict, module, key_asr_type, value_asr_type); + llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); + llvm::Value* key_hash = get_key_hash(current_capacity, key, key_asr_type, *module); + linear_probing_for_write(dict, key_hash, key, value, *module, + key_asr_type, value_asr_type); + } + + llvm::Value* LLVMDict::read_item(llvm::Value* dict, llvm::Value* key, + llvm::Module& module, ASR::ttype_t* key_asr_type, + bool get_pointer) { + llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); + llvm::Value* key_hash = get_key_hash(current_capacity, key, key_asr_type, module); + llvm::Value* value_ptr = linear_probing_for_read(dict, key_hash, key, module, + key_asr_type); + if( get_pointer ) { + return value_ptr; + } + return LLVM::CreateLoad(*builder, value_ptr); + } + + llvm::Value* LLVMList::read_item(llvm::Value* list, llvm::Value* pos, bool get_pointer, + bool check_index_bound) { + if( check_index_bound ) { + check_index_within_bounds(list, pos); + } llvm::Value* list_data = LLVM::CreateLoad(*builder, get_pointer_to_list_data(list)); llvm::Value* element_ptr = llvm_utils->create_ptr_gep(list_data, pos); if( get_pointer ) { @@ -386,6 +869,10 @@ namespace LFortran { return LLVM::CreateLoad(*builder, get_pointer_to_current_end_point(list)); } + llvm::Value* LLVMDict::len(llvm::Value* dict) { + return LLVM::CreateLoad(*builder, get_pointer_to_occupancy(dict)); + } + void LLVMList::resize_if_needed(llvm::Value* list, llvm::Value* n, llvm::Value* capacity, int32_t type_size, llvm::Type* el_type, llvm::Module& module) { @@ -465,10 +952,16 @@ namespace LFortran { * list[pos] = item; */ + // TODO: Should be created outside the user loop and not here. + // LLVMList should treat them as data members and create them + // only if they are NULL llvm::AllocaInst *tmp_ptr = builder->CreateAlloca(el_type, nullptr); LLVM::CreateStore(*builder, read_item(list, pos, false), tmp_ptr); llvm::Value* tmp = nullptr; + // TODO: Should be created outside the user loop and not here. + // LLVMList should treat them as data members and create them + // only if they are NULL llvm::AllocaInst *pos_ptr = builder->CreateAlloca( llvm::Type::getInt32Ty(context), nullptr); LLVM::CreateStore(*builder, pos, pos_ptr); @@ -515,6 +1008,9 @@ namespace LFortran { llvm::Type* pos_type = llvm::Type::getInt32Ty(context); llvm::Value* current_end_point = LLVM::CreateLoad(*builder, get_pointer_to_current_end_point(list)); + // TODO: Should be created outside the user loop and not here. + // LLVMList should treat them as data members and create them + // only if they are NULL llvm::AllocaInst *i = builder->CreateAlloca(pos_type, nullptr); LLVM::CreateStore(*builder, llvm::ConstantInt::get( context, llvm::APInt(32, 0)), i); @@ -599,6 +1095,9 @@ namespace LFortran { llvm::Type* pos_type = llvm::Type::getInt32Ty(context); llvm::Value* current_end_point = LLVM::CreateLoad(*builder, get_pointer_to_current_end_point(list)); + // TODO: Should be created outside the user loop and not here. + // LLVMList should treat them as data members and create them + // only if they are NULL llvm::AllocaInst *item_pos = builder->CreateAlloca(pos_type, nullptr); llvm::Value* tmp = LLVMList::find_item_position(list, item, item_type, module); LLVM::CreateStore(*builder, tmp, item_pos); @@ -654,6 +1153,11 @@ namespace LFortran { LLVM::CreateStore(*builder, zero, end_point_ptr); } + void LLVMList::free_data(llvm::Value* list, llvm::Module& module) { + llvm::Value* data = LLVM::CreateLoad(*builder, get_pointer_to_list_data(list)); + LLVM::lfortran_free(context, module, *builder, data); + } + LLVMTuple::LLVMTuple(llvm::LLVMContext& context_, LLVMUtils* llvm_utils_, From 61443f79eb4f206977d8cd1f1996ff949128d0b4 Mon Sep 17 00:00:00 2001 From: Gagandeep Singh Date: Wed, 17 Aug 2022 20:28:54 +0530 Subject: [PATCH 5/6] Use dict_api to implement DictConstant, Assignment of Dicts and DictInsert --- src/libasr/codegen/asr_to_llvm.cpp | 131 +++++++++++++++++++++++++++++ 1 file changed, 131 insertions(+) diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index 94c8d9f37e..be3c5c0afd 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -221,6 +221,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor std::unique_ptr llvm_utils; std::unique_ptr list_api; std::unique_ptr tuple_api; + std::unique_ptr dict_api; std::unique_ptr arr_descr; uint64_t ptr_loads; @@ -237,6 +238,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor llvm_utils(std::make_unique(context, builder.get())), list_api(std::make_unique(context, llvm_utils.get(), builder.get())), tuple_api(std::make_unique(context, llvm_utils.get(), builder.get())), + dict_api(std::make_unique(context, llvm_utils.get(), builder.get())), arr_descr(LLVMArrUtils::Descriptor::get_descriptor(context, builder.get(), llvm_utils.get(), @@ -246,6 +248,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor { llvm_utils->tuple_api = tuple_api.get(); llvm_utils->list_api = list_api.get(); + llvm_utils->dict_api = dict_api.get(); } llvm::Value* CreateLoad(llvm::Value *x) { @@ -1155,6 +1158,30 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor tmp = const_list; } + void visit_DictConstant(const ASR::DictConstant_t& x) { + llvm::Type* const_dict_type = get_dict_type(x.m_type); + llvm::Value* const_dict = builder->CreateAlloca(const_dict_type, nullptr, "const_dict"); + ASR::Dict_t* x_dict = ASR::down_cast(x.m_type); + std::string key_type_code = ASRUtils::get_type_code(x_dict->m_key_type); + std::string value_type_code = ASRUtils::get_type_code(x_dict->m_value_type); + dict_api->dict_init(key_type_code, value_type_code, const_dict, module.get(), x.n_keys); + uint64_t ptr_loads_key = LLVM::is_llvm_struct(x_dict->m_key_type) ? 0 : 2; + uint64_t ptr_loads_value = LLVM::is_llvm_struct(x_dict->m_value_type) ? 0 : 2; + uint64_t ptr_loads_copy = ptr_loads; + for( size_t i = 0; i < x.n_keys; i++ ) { + ptr_loads = ptr_loads_key; + visit_expr(*x.m_keys[i]); + llvm::Value* key = tmp; + ptr_loads = ptr_loads_value; + visit_expr(*x.m_values[i]); + llvm::Value* value = tmp; + dict_api->write_item(const_dict, key, value, module.get(), + x_dict->m_key_type, x_dict->m_value_type); + } + ptr_loads = ptr_loads_copy; + tmp = const_dict; + } + void visit_TupleConstant(const ASR::TupleConstant_t& x) { ASR::Tuple_t* tuple_type = ASR::down_cast(x.m_type); std::string type_code = ASRUtils::get_type_code(tuple_type->m_type, @@ -1235,6 +1262,23 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor tmp = list_api->read_item(plist, pos, LLVM::is_llvm_struct(el_type)); } + void visit_DictItem(const ASR::DictItem_t& x) { + ASR::Dict_t* dict_type = ASR::down_cast( + ASRUtils::expr_type(x.m_a)); + uint64_t ptr_loads_copy = ptr_loads; + ptr_loads = 0; + this->visit_expr(*x.m_a); + llvm::Value* pdict = tmp; + + ptr_loads = !LLVM::is_llvm_struct(dict_type->m_key_type); + this->visit_expr_wrapper(x.m_key, true); + ptr_loads = ptr_loads_copy; + llvm::Value *key = tmp; + + tmp = dict_api->read_item(pdict, key, *module, dict_type->m_key_type, + LLVM::is_llvm_struct(dict_type->m_value_type)); + } + void visit_ListLen(const ASR::ListLen_t& x) { if (x.m_value) { this->visit_expr(*x.m_value); @@ -1248,6 +1292,20 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } } + void visit_DictLen(const ASR::DictLen_t& x) { + if (x.m_value) { + this->visit_expr(*x.m_value); + return ; + } + + uint64_t ptr_loads_copy = ptr_loads; + ptr_loads = 0; + this->visit_expr(*x.m_arg); + ptr_loads = ptr_loads_copy; + llvm::Value* pdict = tmp; + tmp = dict_api->len(pdict); + } + void visit_ListInsert(const ASR::ListInsert_t& x) { ASR::List_t* asr_list = ASR::down_cast( ASRUtils::expr_type(x.m_a)); @@ -1268,6 +1326,26 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor list_api->insert_item(plist, pos, item, asr_list->m_type, *module); } + void visit_DictInsert(const ASR::DictInsert_t& x) { + ASR::Dict_t* dict_type = ASR::down_cast( + ASRUtils::expr_type(x.m_a)); + uint64_t ptr_loads_copy = ptr_loads; + ptr_loads = 0; + this->visit_expr(*x.m_a); + llvm::Value* pdict = tmp; + + ptr_loads = !LLVM::is_llvm_struct(dict_type->m_key_type); + this->visit_expr_wrapper(x.m_key, true); + llvm::Value *key = tmp; + this->visit_expr_wrapper(x.m_value, true); + llvm::Value *value = tmp; + ptr_loads = ptr_loads_copy; + + dict_api->write_item(pdict, key, value, module.get(), + dict_type->m_key_type, + dict_type->m_value_type); + } + void visit_ListRemove(const ASR::ListRemove_t& x) { ASR::ttype_t* asr_el_type = ASRUtils::get_contained_type(ASRUtils::expr_type(x.m_a)); uint64_t ptr_loads_copy = ptr_loads; @@ -1717,6 +1795,41 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor return false; } + int32_t get_type_size(ASR::ttype_t* asr_type, llvm::Type* llvm_type, + int32_t a_kind) { + if( LLVM::is_llvm_struct(asr_type) || + ASR::is_a(*asr_type) || + ASR::is_a(*asr_type) ) { + llvm::DataLayout data_layout(module.get()); + return data_layout.getTypeAllocSize(llvm_type); + } + return a_kind; + } + + llvm::Type* get_dict_type(ASR::ttype_t* asr_type) { + ASR::Dict_t* asr_dict = ASR::down_cast(asr_type); + bool is_local_array_type = false, is_local_malloc_array_type = false; + bool is_local_list = false; + ASR::dimension_t* local_m_dims = nullptr; + int local_n_dims = 0; + int local_a_kind = -1; + ASR::storage_typeType local_m_storage = ASR::storage_typeType::Default; + llvm::Type* key_llvm_type = get_type_from_ttype_t(asr_dict->m_key_type, local_m_storage, + is_local_array_type, is_local_malloc_array_type, + is_local_list, local_m_dims, local_n_dims, + local_a_kind); + int32_t key_type_size = get_type_size(asr_dict->m_key_type, key_llvm_type, local_a_kind); + llvm::Type* value_llvm_type = get_type_from_ttype_t(asr_dict->m_value_type, local_m_storage, + is_local_array_type, is_local_malloc_array_type, + is_local_list, local_m_dims, local_n_dims, + local_a_kind); + int32_t value_type_size = get_type_size(asr_dict->m_value_type, value_llvm_type, local_a_kind); + std::string key_type_code = ASRUtils::get_type_code(asr_dict->m_key_type); + std::string value_type_code = ASRUtils::get_type_code(asr_dict->m_value_type); + return dict_api->get_dict_type(key_type_code, value_type_code, key_type_size, + value_type_size, key_llvm_type, value_llvm_type); + } + llvm::Type* get_type_from_ttype_t(ASR::ttype_t* asr_type, ASR::storage_typeType m_storage, bool& is_array_type, bool& is_malloc_array_type, @@ -1865,6 +1978,10 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor llvm_type = list_api->get_list_type(el_llvm_type, el_type_code, type_size); break; } + case (ASR::ttypeType::Dict): { + llvm_type = get_dict_type(asr_type); + break; + } case (ASR::ttypeType::Tuple) : { ASR::Tuple_t* asr_tuple = ASR::down_cast(asr_type); std::string type_code = ASRUtils::get_type_code(asr_tuple->m_type, @@ -2959,6 +3076,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor bool is_value_list = ASR::is_a(*asr_value_type); bool is_target_tuple = ASR::is_a(*asr_target_type); bool is_value_tuple = ASR::is_a(*asr_value_type); + bool is_target_dict = ASR::is_a(*asr_target_type); + bool is_value_dict = ASR::is_a(*asr_value_type); if( is_target_list && is_value_list ) { uint64_t ptr_loads_copy = ptr_loads; ptr_loads = 0; @@ -3014,6 +3133,18 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } } return ; + } else if( is_target_dict && is_value_dict ) { + uint64_t ptr_loads_copy = ptr_loads; + ptr_loads = 0; + this->visit_expr(*x.m_value); + llvm::Value* value_dict = tmp; + this->visit_expr(*x.m_target); + llvm::Value* target_dict = tmp; + ptr_loads = ptr_loads_copy; + ASR::Dict_t* value_dict_type = ASR::down_cast(asr_value_type); + dict_api->dict_deepcopy(value_dict, target_dict, + value_dict_type, module.get()); + return ; } if( ASR::is_a(*ASRUtils::expr_type(x.m_target)) && ASR::is_a(*x.m_value) ) { From 0bdb3f773bb09e8783223ecd4d8b748aa37f834a Mon Sep 17 00:00:00 2001 From: Gagandeep Singh Date: Wed, 17 Aug 2022 20:29:11 +0530 Subject: [PATCH 6/6] Registered test_dict_01 --- integration_tests/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index 8fa1141988..8e29f2e10f 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -156,6 +156,7 @@ RUN(NAME test_list_06 LABELS cpython llvm) RUN(NAME test_list_07 LABELS cpython llvm) RUN(NAME test_tuple_01 LABELS cpython llvm) RUN(NAME test_tuple_02 LABELS cpython llvm) +RUN(NAME test_dict_01 LABELS cpython llvm) RUN(NAME modules_01 LABELS cpython llvm) RUN(NAME modules_02 LABELS cpython llvm) RUN(NAME test_math LABELS cpython llvm)