@@ -659,6 +659,48 @@ def read_page(
659659 if array_cache is not None :
660660 array_cache [key ] = destination .copy ()
661661
    def _expected_array_length_and_starts(
        self,
        col_idx: int,
        cluster_start: int,
        cluster_stop: int,
        missing_element_padding: int = 0,
    ) -> "tuple[int, list[int]]":
        """
        Compute the total element count of a column over a cluster range and
        the start index of each cluster within the combined array, without
        reading any page data.

        Args:
            col_idx (int): The column index.
            cluster_start (int): The first cluster to include.
            cluster_stop (int): The first cluster to exclude (i.e. one greater
                than the last cluster to include).
            missing_element_padding (int): Number of padding elements to add at
                the start of the array.

        Returns the expected length of the array over the given cluster range,
        including padding, and also the start indices of each cluster.
        """
        field_metadata = self.get_field_metadata(col_idx)
        if field_metadata.dtype_byte in uproot.const.rntuple_index_types:
            # for offsets we need an extra zero at the start
            missing_element_padding += 1
        # The padding occupies the front of the array, so the first cluster
        # starts right after it.
        total_length = missing_element_padding
        starts = []
        for cluster_idx in range(cluster_start, cluster_stop):
            linklist = self._ntuple.page_link_list[cluster_idx]
            # Check if the column is suppressed and pick the non-suppressed one if so
            if col_idx < len(linklist) and linklist[col_idx].suppressed:
                # All column records belonging to the same field; one of them
                # must be unsuppressed in this cluster.
                rel_crs = self._column_records_dict[
                    self.column_records[col_idx].field_id
                ]
                # NOTE: this rebinds the local col_idx, so the redirect also
                # applies to the remaining clusters in this loop (until another
                # suppressed column triggers a new redirect).
                col_idx = next(
                    cr.idx for cr in rel_crs if not linklist[cr.idx].suppressed
                )
                field_metadata = self.get_field_metadata(
                    col_idx
                )  # Update metadata if suppressed
            # A column index beyond the link list means no pages in this
            # cluster, i.e. zero elements.
            pagelist = (
                linklist[field_metadata.ncol].pages
                if field_metadata.ncol < len(linklist)
                else []
            )
            cluster_length = sum(desc.num_elements for desc in pagelist)
            starts.append(total_length)
            total_length += cluster_length

        return total_length, starts
703+
662704 def read_cluster_range (
663705 self ,
664706 col_idx ,
@@ -678,18 +720,27 @@ def read_cluster_range(
678720 Returns a numpy array with the data from the column.
679721 """
680722 field_metadata = self .get_field_metadata (col_idx )
681- arrays = [
723+ total_length , starts = self ._expected_array_length_and_starts (
724+ col_idx , cluster_start , cluster_stop , missing_element_padding
725+ )
726+ res = numpy .empty (total_length , field_metadata .dtype_result )
727+ # Initialize the padding elements. Note that it might be different from missing_element_padding
728+ # because for offsets there is an extra zero added at the start.
729+ assert len (starts ) > 0 , "The cluster range is invalid"
730+ res [: starts [0 ]] = 0
731+
732+ for i , cluster_idx in enumerate (range (cluster_start , cluster_stop )):
733+ stop = starts [i + 1 ] if i != len (starts ) - 1 else total_length
682734 self .read_pages (
683735 cluster_idx ,
684736 col_idx ,
685737 field_metadata ,
738+ destination = res [starts [i ] : stop ],
686739 array_cache = array_cache ,
687740 )
688- for cluster_idx in range (cluster_start , cluster_stop )
689- ]
690- res = self .combine_cluster_arrays (
691- arrays , field_metadata , missing_element_padding
692- )
741+
742+ self .combine_cluster_arrays (res , starts , field_metadata )
743+ # TODO: Fix type here?
693744
694745 return res
695746
@@ -698,17 +749,17 @@ def read_pages(
698749 cluster_idx ,
699750 col_idx ,
700751 field_metadata ,
752+ destination = None ,
701753 array_cache = None ,
702754 ):
703755 """
704756 Args:
757+ destination (numpy.ndarray): The array to fill.
705758 cluster_idx (int): The cluster index.
706759 col_idx (int): The column index.
707760 field_metadata (:doc:`uproot.models.RNTuple.FieldClusterMetadata`):
708761 The metadata needed to read the field's pages.
709762 array_cache (None or MutableMapping): Cache of arrays. If None, do not use a cache.
710-
711- Returns a numpy array with the data from the column.
712763 """
713764 linklist = self ._ntuple .page_link_list [cluster_idx ]
714765 # Check if the column is suppressed and pick the non-suppressed one if so
@@ -724,28 +775,38 @@ def read_pages(
724775 else []
725776 )
726777 total_len = numpy .sum ([desc .num_elements for desc in pagelist ], dtype = int )
727- res = numpy .empty (total_len , field_metadata .dtype )
778+ if destination is None :
779+ return_buffer = True
780+ destination = numpy .empty (total_len , dtype = field_metadata .dtype )
781+ else :
782+ return_buffer = False
783+ assert len (destination ) == total_len
728784
729785 tracker = 0
730786 cumsum = 0
731787 for page_idx , page_desc in enumerate (pagelist ):
732788 n_elements = page_desc .num_elements
733789 tracker_end = tracker + n_elements
734790 self .read_page (
735- res [tracker :tracker_end ],
791+ destination [tracker :tracker_end ],
736792 cluster_idx ,
737793 col_idx ,
738794 page_idx ,
739795 field_metadata ,
740796 array_cache = array_cache ,
741797 )
798+ if field_metadata .dtype != field_metadata .dtype_result :
799+ destination [tracker :tracker_end ] = destination [
800+ tracker :tracker_end
801+ ].view (field_metadata .dtype )[: tracker_end - tracker ]
742802 if field_metadata .delta :
743- res [tracker ] -= cumsum
744- cumsum += numpy .sum (res [tracker :tracker_end ])
803+ destination [tracker ] -= cumsum
804+ cumsum += numpy .sum (destination [tracker :tracker_end ])
745805 tracker = tracker_end
746806
747- res = self .post_process (res , field_metadata )
748- return res
807+ self .post_process (destination , field_metadata )
808+ if return_buffer :
809+ return destination
749810
750811 def gpu_read_clusters (self , fields , start_cluster_idx , stop_cluster_idx ):
751812 """
@@ -964,26 +1025,25 @@ def post_process(self, buffer, field_metadata):
9641025 field_metadata (:doc:`uproot.models.RNTuple.FieldClusterMetadata`):
9651026 The metadata needed to post_process buffer.
9661027
967- Returns post-processed buffer.
1028+ Performs some post-processing on the buffer in place .
9681029 """
9691030 array_library_string = uproot ._util .get_array_library (buffer )
9701031 library = numpy if array_library_string == "numpy" else uproot .extras .cupy ()
9711032 if field_metadata .zigzag :
972- buffer = _from_zigzag (buffer )
1033+ buffer [:] = _from_zigzag (buffer )
9731034 elif field_metadata .delta :
974- buffer = library .cumsum (buffer )
1035+ buffer [:] = library .cumsum (buffer )
9751036 elif field_metadata .dtype_str == "real32trunc" :
976- buffer = buffer . view ( library .float32 )
1037+ buffer . dtype = library .float32
9771038 elif field_metadata .dtype_str == "real32quant" and field_metadata .ncol < len (
9781039 self .column_records
9791040 ):
9801041 min_value = self .column_records [field_metadata .ncol ].min_value
9811042 max_value = self .column_records [field_metadata .ncol ].max_value
982- buffer = min_value + buffer .astype (library .float32 ) * (
1043+ buffer .dtype = library .float32
1044+ buffer [:] = min_value + buffer .view (library .uint32 ) * (
9831045 max_value - min_value
9841046 ) / ((1 << field_metadata .nbits ) - 1 )
985- buffer = buffer .astype (library .float32 )
986- return buffer
9871047
9881048 def deserialize_page_decompressed_buffer (self , destination , field_metadata ):
9891049 """
@@ -1081,6 +1141,28 @@ def get_field_metadata(self, ncol):
10811141 dtype_toread = numpy .dtype ("uint8" )
10821142 else :
10831143 dtype_toread = dtype
1144+
1145+ rel_crs = self ._column_records_dict [self .column_records [ncol ].field_id ]
1146+ alt_dtype_list = []
1147+ for cr in rel_crs :
1148+ alt_dtype_byte = self .column_records [cr .idx ].type
1149+ alt_dtype_str = uproot .const .rntuple_col_num_to_dtype_dict [alt_dtype_byte ]
1150+ if alt_dtype_str == "switch" :
1151+ alt_dtype = numpy .dtype ([("index" , "int64" ), ("tag" , "int32" )])
1152+ elif alt_dtype_str == "bit" :
1153+ alt_dtype = numpy .dtype ("bool" )
1154+ elif alt_dtype_byte in uproot .const .rntuple_custom_float_types :
1155+ alt_dtype = numpy .dtype ("uint32" ) # for easier bit manipulation
1156+ else :
1157+ alt_dtype = numpy .dtype (alt_dtype_str )
1158+ alt_dtype_list .append (alt_dtype )
1159+ # We want to skip doing this for strings.
1160+ if self .field_records [self .column_records [ncol ].field_id ].type_name .startswith (
1161+ "std::string"
1162+ ):
1163+ dtype_result = dtype
1164+ else :
1165+ dtype_result = numpy .result_type (* alt_dtype_list )
10841166 field_metadata = FieldClusterMetadata (
10851167 ncol ,
10861168 dtype_byte ,
@@ -1092,46 +1174,28 @@ def get_field_metadata(self, ncol):
10921174 delta ,
10931175 isbit ,
10941176 nbits ,
1177+ dtype_result ,
10951178 )
10961179 return field_metadata
10971180
1098- def combine_cluster_arrays (self , arrays , field_metadata , missing_element_padding ):
1181+ def combine_cluster_arrays (self , array , starts , field_metadata ):
10991182 """
11001183 Args:
1101- arrays (list): A list of arrays to combine.
1184+ array (numpy.ndarray): An array with the full data.
1185+ starts (list): An array with the start indices of each cluster.
11021186 field_metadata (:doc:`uproot.models.RNTuple.FieldClusterMetadata`):
11031187 The metadata needed to combine arrays.
1104- missing_element_padding (int): Number of padding elements to add at the start of the array.
11051188
11061189 Returns a field's page arrays concatenated together.
11071190 """
1108- array_library_string = uproot ._util .get_array_library (arrays [0 ])
1109- library = numpy if array_library_string == "numpy" else uproot .extras .cupy ()
1110-
11111191 # Check if column stores offset values
11121192 if field_metadata .dtype_byte in uproot .const .rntuple_index_types :
1113- # Extract the last offset values:
1114- last_elements = [
1115- (arr [- 1 ] if len (arr ) > 0 else library .zeros ((), dtype = arr .dtype ))
1116- for arr in arrays [:- 1 ]
1117- ] # First value always zero, therefore skip first arr.
1118- last_offsets = library .cumsum (library .array (last_elements ))
1119- for i in range (1 , len (arrays )):
1120- arrays [i ] += last_offsets [i - 1 ]
1121-
1122- res = library .concatenate (arrays , axis = 0 )
1123-
1124- # No longer needed; free memory
1125- del arrays
1126-
1127- if field_metadata .dtype_byte in uproot .const .rntuple_index_types :
1128- # for offsets
1129- res = numpy .insert (res , 0 , 0 ) if library == numpy else _cupy_insert0 (res )
1130-
1131- if missing_element_padding :
1132- res = numpy .pad (res , (missing_element_padding , 0 ))
1133-
1134- return res
1193+ for i in range (1 , len (starts )):
1194+ start = starts [i ]
1195+ stop = starts [i + 1 ] if i != len (starts ) - 1 else len (array )
1196+ if start == stop :
1197+ continue
1198+ array [start :stop ] += array [start - 1 ]
11351199
11361200
11371201def _extract_bits (packed , nbits ):
@@ -1772,6 +1836,7 @@ class FieldClusterMetadata:
17721836 delta : bool
17731837 isbit : bool
17741838 nbits : int
1839+ dtype_result : numpy .dtype
17751840
17761841
17771842@dataclasses .dataclass
0 commit comments