diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index 03e9b32f6d..a6b3b73e80 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -1756,6 +1756,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } void visit_Program(const ASR::Program_t &x) { + bool is_dict_present_copy = dict_api->is_dict_present; + dict_api->is_dict_present = false; llvm_goto_targets.clear(); // Generate code for nested subroutines and functions first: for (auto &item : x.m_symtab->get_scope()) { @@ -1784,6 +1786,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor llvm::Value *ret_val2 = llvm::ConstantInt::get(context, llvm::APInt(32, 0)); builder->CreateRet(ret_val2); + dict_api->is_dict_present = is_dict_present_copy; } /* @@ -2586,6 +2589,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } void visit_Function(const ASR::Function_t &x) { + bool is_dict_present_copy = dict_api->is_dict_present; + dict_api->is_dict_present = false; llvm_goto_targets.clear(); instantiate_function(x); if (x.m_deftype == ASR::deftypeType::Interface) { @@ -2596,6 +2601,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor visit_procedures(x); generate_function(x); parent_function = nullptr; + dict_api->is_dict_present = is_dict_present_copy; } void instantiate_function(const ASR::Function_t &x){ @@ -3577,6 +3583,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } void visit_WhileLoop(const ASR::WhileLoop_t &x) { + dict_api->set_iterators(); llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); @@ -3598,6 +3605,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor // end start_new_block(loopend); + dict_api->reset_iterators(); } void visit_Exit(const ASR::Exit_t & /* x */) { diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index 9b62f64151..480c6297ec 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -282,7 +282,8 @@ namespace LFortran { llvm_utils(std::move(llvm_utils_)), builder(std::move(builder_)), pos_ptr(nullptr), is_key_matching_var(nullptr), - are_iterators_set(false) { + idx_ptr(nullptr), are_iterators_set(false), + is_dict_present(false) { } llvm::Type* LLVMList::get_list_type(llvm::Type* el_type, std::string& type_code, @@ -301,6 +302,7 @@ namespace LFortran { llvm::Type* LLVMDict::get_dict_type(std::string key_type_code, std::string value_type_code, int32_t key_type_size, int32_t value_type_size, llvm::Type* key_type, llvm::Type* value_type) { + is_dict_present = true; std::pair llvm_key = std::make_pair(key_type_code, value_type_code); if( typecode2dicttype.find(llvm_key) != typecode2dicttype.end() ) { return std::get<0>(typecode2dicttype[llvm_key]); @@ -554,17 +556,26 @@ namespace LFortran { } void LLVMDict::set_iterators() { - if( are_iterators_set ) { + if( are_iterators_set || !is_dict_present ) { return ; } - pos_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); - is_key_matching_var = builder->CreateAlloca(llvm::Type::getInt1Ty(context), nullptr); + pos_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr, "pos_ptr"); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 0)), pos_ptr); + is_key_matching_var = builder->CreateAlloca(llvm::Type::getInt1Ty(context), nullptr, + "is_key_matching_var"); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), + llvm::APInt(1, 0)), is_key_matching_var); + idx_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr, "idx_ptr"); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 0)), idx_ptr); are_iterators_set = true; } void LLVMDict::reset_iterators() { pos_ptr = nullptr; is_key_matching_var = nullptr; + idx_ptr = nullptr; are_iterators_set = false; } @@ -572,7 +583,10 @@ namespace LFortran { llvm::Value* key, llvm::Value* key_list, llvm::Value* key_mask, llvm::Module& module, ASR::ttype_t* key_asr_type) { - set_iterators(); + if( !are_iterators_set ) { + pos_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + is_key_matching_var = builder->CreateAlloca(llvm::Type::getInt1Ty(context), nullptr); + } LLVM::CreateStore(*builder, key_hash, pos_ptr); llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); @@ -654,7 +668,6 @@ namespace LFortran { LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), llvm::APInt(1, 1)), llvm_utils->create_ptr_gep(key_mask, pos)); - reset_iterators(); } llvm::Value* LLVMDict::linear_probing_for_read(llvm::Value* dict, llvm::Value* key_hash, @@ -667,7 +680,6 @@ namespace LFortran { linear_probing(capacity, key_hash, key, key_list, key_mask, module, key_asr_type); llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); llvm::Value* item = llvm_utils->list_api->read_item(value_list, pos, true, false); - reset_iterators(); return item; } @@ -729,10 +741,9 @@ namespace LFortran { new_key_mask = builder->CreateBitCast(new_key_mask, llvm::Type::getInt1PtrTy(context)); llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); - // TODO: Should be created outside the user loop and not here. - // LLVMDict should treat them as data members and create them - // only if they are NULL - llvm::Value* idx_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + if( !are_iterators_set ) { + idx_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + } LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0)), idx_ptr); @@ -790,7 +801,6 @@ namespace LFortran { // end llvm_utils->start_new_block(loopend); - reset_iterators(); // TODO: Free key_list, value_list and key_mask llvm_utils->list_api->free_data(key_list, *module); diff --git a/src/libasr/codegen/llvm_utils.h b/src/libasr/codegen/llvm_utils.h index 1ec2c47cfd..b9c04f5fbd 100644 --- a/src/libasr/codegen/llvm_utils.h +++ b/src/libasr/codegen/llvm_utils.h @@ -223,6 +223,7 @@ namespace LFortran { LLVMUtils* llvm_utils; llvm::IRBuilder<>* builder; llvm::AllocaInst *pos_ptr, *is_key_matching_var; + llvm::AllocaInst *idx_ptr; bool are_iterators_set; std::map, @@ -231,6 +232,8 @@ namespace LFortran { public: + bool is_dict_present; + LLVMDict(llvm::LLVMContext& context_, LLVMUtils* llvm_utils, llvm::IRBuilder<>* builder);