@@ -286,6 +286,13 @@ namespace LFortran {
286
286
is_dict_present (false ) {
287
287
}
288
288
289
+ LLVMDictOptimizedLinearProbing::LLVMDictOptimizedLinearProbing (
290
+ llvm::LLVMContext& context_,
291
+ LLVMUtils* llvm_utils_,
292
+ llvm::IRBuilder<>* builder_):
293
+ LLVMDict (context_, llvm_utils_, builder_) {
294
+ }
295
+
289
296
llvm::Type* LLVMList::get_list_type (llvm::Type* el_type, std::string& type_code,
290
297
int32_t type_size) {
291
298
if ( typecode2listtype.find (type_code) != typecode2listtype.end () ) {
@@ -577,10 +584,86 @@ namespace LFortran {
577
584
are_iterators_set = false ;
578
585
}
579
586
580
- void LLVMDict::linear_probing (llvm::Value* capacity, llvm::Value* key_hash,
581
- llvm::Value* key, llvm::Value* key_list,
582
- llvm::Value* key_mask, llvm::Module& module,
583
- ASR::ttype_t * key_asr_type, bool for_read) {
587
+ void LLVMDict::resolve_collision (
588
+ llvm::Value* capacity, llvm::Value* key_hash,
589
+ llvm::Value* key, llvm::Value* key_list,
590
+ llvm::Value* key_mask, llvm::Module& module,
591
+ ASR::ttype_t * key_asr_type, bool /* for_read*/ ) {
592
+ if ( !are_iterators_set ) {
593
+ pos_ptr = builder->CreateAlloca (llvm::Type::getInt32Ty (context), nullptr );
594
+ is_key_matching_var = builder->CreateAlloca (llvm::Type::getInt1Ty (context), nullptr );
595
+ }
596
+ LLVM::CreateStore (*builder, key_hash, pos_ptr);
597
+
598
+
599
+ llvm::BasicBlock *loophead = llvm::BasicBlock::Create (context, " loop.head" );
600
+ llvm::BasicBlock *loopbody = llvm::BasicBlock::Create (context, " loop.body" );
601
+ llvm::BasicBlock *loopend = llvm::BasicBlock::Create (context, " loop.end" );
602
+
603
+
604
+ // head
605
+ llvm_utils->start_new_block (loophead);
606
+ {
607
+ llvm::Value* pos = LLVM::CreateLoad (*builder, pos_ptr);
608
+ llvm::Value* is_key_set = LLVM::CreateLoad (*builder,
609
+ llvm_utils->create_ptr_gep (key_mask, pos));
610
+ is_key_set = builder->CreateICmpNE (is_key_set,
611
+ llvm::ConstantInt::get (llvm::Type::getInt8Ty (context), llvm::APInt (8 , 0 )));
612
+ llvm::Value* is_key_matching = llvm::ConstantInt::get (llvm::Type::getInt1Ty (context),
613
+ llvm::APInt (1 , 0 ));
614
+ LLVM::CreateStore (*builder, is_key_matching, is_key_matching_var);
615
+ llvm::Function *fn = builder->GetInsertBlock ()->getParent ();
616
+ llvm::BasicBlock *thenBB = llvm::BasicBlock::Create (context, " then" , fn);
617
+ llvm::BasicBlock *elseBB = llvm::BasicBlock::Create (context, " else" );
618
+ llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create (context, " ifcont" );
619
+ builder->CreateCondBr (is_key_set, thenBB, elseBB);
620
+ builder->SetInsertPoint (thenBB);
621
+ {
622
+ llvm::Value* original_key = llvm_utils->list_api ->read_item (key_list, pos,
623
+ LLVM::is_llvm_struct (key_asr_type), false );
624
+ is_key_matching = llvm_utils->is_equal_by_value (key, original_key, module,
625
+ key_asr_type);
626
+ LLVM::CreateStore (*builder, is_key_matching, is_key_matching_var);
627
+ }
628
+ builder->CreateBr (mergeBB);
629
+
630
+
631
+ llvm_utils->start_new_block (elseBB);
632
+ llvm_utils->start_new_block (mergeBB);
633
+ // TODO: Allow safe exit if pos becomes key_hash again.
634
+ // Ideally should not happen as dict will be resized once
635
+ // load factor touches a threshold (which will always be less than 1)
636
+ // so there will be some key which will not be set. However for safety
637
+ // we can add an exit from the loop with a error message.
638
+ llvm::Value *cond = builder->CreateAnd (is_key_set, builder->CreateNot (
639
+ LLVM::CreateLoad (*builder, is_key_matching_var)));
640
+ builder->CreateCondBr (cond, loopbody, loopend);
641
+ }
642
+
643
+
644
+ // body
645
+ llvm_utils->start_new_block (loopbody);
646
+ {
647
+ llvm::Value* pos = LLVM::CreateLoad (*builder, pos_ptr);
648
+ pos = builder->CreateAdd (pos, llvm::ConstantInt::get (llvm::Type::getInt32Ty (context),
649
+ llvm::APInt (32 , 1 )));
650
+ pos = builder->CreateSRem (pos, capacity);
651
+ LLVM::CreateStore (*builder, pos, pos_ptr);
652
+ }
653
+
654
+
655
+ builder->CreateBr (loophead);
656
+
657
+
658
+ // end
659
+ llvm_utils->start_new_block (loopend);
660
+ }
661
+
662
+ void LLVMDictOptimizedLinearProbing::resolve_collision (
663
+ llvm::Value* capacity, llvm::Value* key_hash,
664
+ llvm::Value* key, llvm::Value* key_list,
665
+ llvm::Value* key_mask, llvm::Module& module,
666
+ ASR::ttype_t * key_asr_type, bool for_read) {
584
667
if ( !are_iterators_set ) {
585
668
if ( !for_read ) {
586
669
pos_ptr = builder->CreateAlloca (llvm::Type::getInt32Ty (context), nullptr );
@@ -648,15 +731,45 @@ namespace LFortran {
648
731
llvm_utils->start_new_block (loopend);
649
732
}
650
733
651
- void LLVMDict::linear_probing_for_write (llvm::Value* dict, llvm::Value* key_hash,
652
- llvm::Value* key, llvm::Value* value,
653
- llvm::Module& module, ASR::ttype_t * key_asr_type,
654
- ASR::ttype_t * value_asr_type) {
734
+ void LLVMDict::resolve_collision_for_write (
735
+ llvm::Value* dict, llvm::Value* key_hash,
736
+ llvm::Value* key, llvm::Value* value,
737
+ llvm::Module& module, ASR::ttype_t * key_asr_type,
738
+ ASR::ttype_t * value_asr_type) {
655
739
llvm::Value* key_list = get_key_list (dict);
656
740
llvm::Value* value_list = get_value_list (dict);
657
741
llvm::Value* key_mask = LLVM::CreateLoad (*builder, get_pointer_to_keymask (dict));
658
742
llvm::Value* capacity = LLVM::CreateLoad (*builder, get_pointer_to_capacity (dict));
659
- linear_probing (capacity, key_hash, key, key_list, key_mask, module, key_asr_type);
743
+ this ->resolve_collision (capacity, key_hash, key, key_list, key_mask, module, key_asr_type);
744
+ llvm::Value* pos = LLVM::CreateLoad (*builder, pos_ptr);
745
+ llvm_utils->list_api ->write_item (key_list, pos, key,
746
+ key_asr_type, module, false );
747
+ llvm_utils->list_api ->write_item (value_list, pos, value,
748
+ value_asr_type, module, false );
749
+ llvm::Value* key_mask_value = LLVM::CreateLoad (*builder,
750
+ llvm_utils->create_ptr_gep (key_mask, pos));
751
+ llvm::Value* is_slot_empty = builder->CreateICmpEQ (key_mask_value,
752
+ llvm::ConstantInt::get (llvm::Type::getInt8Ty (context), llvm::APInt (8 , 0 )));
753
+ llvm::Value* occupancy_ptr = get_pointer_to_occupancy (dict);
754
+ is_slot_empty = builder->CreateZExt (is_slot_empty, llvm::Type::getInt32Ty (context));
755
+ llvm::Value* occupancy = LLVM::CreateLoad (*builder, occupancy_ptr);
756
+ LLVM::CreateStore (*builder, builder->CreateAdd (occupancy, is_slot_empty),
757
+ occupancy_ptr);
758
+ LLVM::CreateStore (*builder,
759
+ llvm::ConstantInt::get (llvm::Type::getInt8Ty (context), llvm::APInt (8 , 1 )),
760
+ llvm_utils->create_ptr_gep (key_mask, pos));
761
+ }
762
+
763
+ void LLVMDictOptimizedLinearProbing::resolve_collision_for_write (
764
+ llvm::Value* dict, llvm::Value* key_hash,
765
+ llvm::Value* key, llvm::Value* value,
766
+ llvm::Module& module, ASR::ttype_t * key_asr_type,
767
+ ASR::ttype_t * value_asr_type) {
768
+ llvm::Value* key_list = get_key_list (dict);
769
+ llvm::Value* value_list = get_value_list (dict);
770
+ llvm::Value* key_mask = LLVM::CreateLoad (*builder, get_pointer_to_keymask (dict));
771
+ llvm::Value* capacity = LLVM::CreateLoad (*builder, get_pointer_to_capacity (dict));
772
+ this ->resolve_collision (capacity, key_hash, key, key_list, key_mask, module, key_asr_type);
660
773
llvm::Value* pos = LLVM::CreateLoad (*builder, pos_ptr);
661
774
llvm_utils->list_api ->write_item (key_list, pos, key,
662
775
key_asr_type, module, false );
@@ -681,9 +794,24 @@ namespace LFortran {
681
794
LLVM::CreateStore (*builder, set_max_2, llvm_utils->create_ptr_gep (key_mask, pos));
682
795
}
683
796
684
- llvm::Value* LLVMDict::linear_probing_for_read (llvm::Value* dict, llvm::Value* key_hash,
685
- llvm::Value* key, llvm::Module& module,
686
- ASR::ttype_t * key_asr_type) {
797
+ llvm::Value* LLVMDict::resolve_collision_for_read (
798
+ llvm::Value* dict, llvm::Value* key_hash,
799
+ llvm::Value* key, llvm::Module& module,
800
+ ASR::ttype_t * key_asr_type) {
801
+ llvm::Value* key_list = get_key_list (dict);
802
+ llvm::Value* value_list = get_value_list (dict);
803
+ llvm::Value* key_mask = LLVM::CreateLoad (*builder, get_pointer_to_keymask (dict));
804
+ llvm::Value* capacity = LLVM::CreateLoad (*builder, get_pointer_to_capacity (dict));
805
+ this ->resolve_collision (capacity, key_hash, key, key_list, key_mask, module, key_asr_type);
806
+ llvm::Value* pos = LLVM::CreateLoad (*builder, pos_ptr);
807
+ llvm::Value* item = llvm_utils->list_api ->read_item (value_list, pos, true , false );
808
+ return item;
809
+ }
810
+
811
+ llvm::Value* LLVMDictOptimizedLinearProbing::resolve_collision_for_read (
812
+ llvm::Value* dict, llvm::Value* key_hash,
813
+ llvm::Value* key, llvm::Module& module,
814
+ ASR::ttype_t * key_asr_type) {
687
815
llvm::Value* key_list = get_key_list (dict);
688
816
llvm::Value* value_list = get_value_list (dict);
689
817
llvm::Value* key_mask = LLVM::CreateLoad (*builder, get_pointer_to_keymask (dict));
@@ -707,7 +835,7 @@ namespace LFortran {
707
835
builder->CreateBr (mergeBB);
708
836
llvm_utils->start_new_block (elseBB);
709
837
{
710
- linear_probing (capacity, key_hash, key, key_list, key_mask,
838
+ this -> resolve_collision (capacity, key_hash, key, key_list, key_mask,
711
839
module, key_asr_type, true );
712
840
}
713
841
llvm_utils->start_new_block (mergeBB);
@@ -809,7 +937,7 @@ namespace LFortran {
809
937
llvm::Value* value = llvm_utils->list_api ->read_item (value_list, idx,
810
938
LLVM::is_llvm_struct (value_asr_type), false );
811
939
llvm::Value* key_hash = get_key_hash (current_capacity, key, key_asr_type, *module);
812
- linear_probing (current_capacity, key_hash, key, new_key_list,
940
+ this -> resolve_collision (current_capacity, key_hash, key, new_key_list,
813
941
new_key_mask, *module, key_asr_type);
814
942
llvm::Value* pos = LLVM::CreateLoad (*builder, pos_ptr);
815
943
llvm::Value* key_dest = llvm_utils->list_api ->read_item (new_key_list, pos,
@@ -883,17 +1011,17 @@ namespace LFortran {
883
1011
rehash_all_at_once_if_needed (dict, module, key_asr_type, value_asr_type);
884
1012
llvm::Value* current_capacity = LLVM::CreateLoad (*builder, get_pointer_to_capacity (dict));
885
1013
llvm::Value* key_hash = get_key_hash (current_capacity, key, key_asr_type, *module);
886
- linear_probing_for_write (dict, key_hash, key, value, *module,
887
- key_asr_type, value_asr_type);
1014
+ this -> resolve_collision_for_write (dict, key_hash, key, value, *module,
1015
+ key_asr_type, value_asr_type);
888
1016
}
889
1017
890
1018
llvm::Value* LLVMDict::read_item (llvm::Value* dict, llvm::Value* key,
891
1019
llvm::Module& module, ASR::ttype_t * key_asr_type,
892
1020
bool get_pointer) {
893
1021
llvm::Value* current_capacity = LLVM::CreateLoad (*builder, get_pointer_to_capacity (dict));
894
1022
llvm::Value* key_hash = get_key_hash (current_capacity, key, key_asr_type, module);
895
- llvm::Value* value_ptr = linear_probing_for_read (dict, key_hash, key, module,
896
- key_asr_type);
1023
+ llvm::Value* value_ptr = this -> resolve_collision_for_read (dict, key_hash, key, module,
1024
+ key_asr_type);
897
1025
if ( get_pointer ) {
898
1026
return value_ptr;
899
1027
}
@@ -921,6 +1049,10 @@ namespace LFortran {
921
1049
return LLVM::CreateLoad (*builder, get_pointer_to_occupancy (dict));
922
1050
}
923
1051
1052
+ LLVMDict::~LLVMDict () {}
1053
+
1054
+ LLVMDictOptimizedLinearProbing::~LLVMDictOptimizedLinearProbing () {}
1055
+
924
1056
void LLVMList::resize_if_needed (llvm::Value* list, llvm::Value* n,
925
1057
llvm::Value* capacity, int32_t type_size,
926
1058
llvm::Type* el_type, llvm::Module& module) {
0 commit comments