Commit efced74

Perform all reading operations in place

1 parent 62aaa7d commit efced74

1 file changed: +114 −49 lines changed

src/uproot/models/RNTuple.py

Lines changed: 114 additions & 49 deletions
@@ -659,6 +659,48 @@ def read_page(
         if array_cache is not None:
             array_cache[key] = destination.copy()
 
+    def _expected_array_length_and_starts(
+        self, col_idx, cluster_start, cluster_stop, missing_element_padding=0
+    ):
+        """
+        Args:
+            col_idx (int): The column index.
+            cluster_start (int): The first cluster to include.
+            cluster_stop (int): The first cluster to exclude (i.e. one greater than the last cluster to include).
+            missing_element_padding (int): Number of padding elements to add at the start of the array.
+
+        Returns the expected length of the array over the given cluster range, including padding, and also the start indices of each cluster.
+        """
+        field_metadata = self.get_field_metadata(col_idx)
+        if field_metadata.dtype_byte in uproot.const.rntuple_index_types:
+            # for offsets we need an extra zero at the start
+            missing_element_padding += 1
+        total_length = missing_element_padding
+        starts = []
+        for cluster_idx in range(cluster_start, cluster_stop):
+            linklist = self._ntuple.page_link_list[cluster_idx]
+            # Check if the column is suppressed and pick the non-suppressed one if so
+            if col_idx < len(linklist) and linklist[col_idx].suppressed:
+                rel_crs = self._column_records_dict[
+                    self.column_records[col_idx].field_id
+                ]
+                col_idx = next(
+                    cr.idx for cr in rel_crs if not linklist[cr.idx].suppressed
+                )
+                field_metadata = self.get_field_metadata(
+                    col_idx
+                )  # Update metadata if suppressed
+            pagelist = (
+                linklist[field_metadata.ncol].pages
+                if field_metadata.ncol < len(linklist)
+                else []
+            )
+            cluster_length = sum(desc.num_elements for desc in pagelist)
+            starts.append(total_length)
+            total_length += cluster_length
+
+        return total_length, starts
+
     def read_cluster_range(
         self,
         col_idx,
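
To see what `_expected_array_length_and_starts` computes, here is a minimal standalone sketch (made-up cluster lengths, not uproot's API) for an offset/index column, which reserves one extra slot for the leading zero:

```python
# Hypothetical per-cluster element counts (the sum of num_elements over
# each cluster's pages); an index column adds one slot for a leading zero.
cluster_lengths = [3, 5, 4]
missing_element_padding = 0
is_index_column = True  # assumption for this sketch

if is_index_column:
    missing_element_padding += 1

total_length = missing_element_padding
starts = []
for length in cluster_lengths:
    starts.append(total_length)  # index where this cluster's data begins
    total_length += length

print(total_length, starts)  # 13 [1, 4, 9]
```
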
@@ -678,18 +720,27 @@ def read_cluster_range(
         Returns a numpy array with the data from the column.
         """
         field_metadata = self.get_field_metadata(col_idx)
-        arrays = [
+        total_length, starts = self._expected_array_length_and_starts(
+            col_idx, cluster_start, cluster_stop, missing_element_padding
+        )
+        res = numpy.empty(total_length, field_metadata.dtype_result)
+        # Initialize the padding elements. Note that it might be different from missing_element_padding
+        # because for offsets there is an extra zero added at the start.
+        assert len(starts) > 0, "The cluster range is invalid"
+        res[: starts[0]] = 0
+
+        for i, cluster_idx in enumerate(range(cluster_start, cluster_stop)):
+            stop = starts[i + 1] if i != len(starts) - 1 else total_length
             self.read_pages(
                 cluster_idx,
                 col_idx,
                 field_metadata,
+                destination=res[starts[i] : stop],
                 array_cache=array_cache,
             )
-            for cluster_idx in range(cluster_start, cluster_stop)
-        ]
-        res = self.combine_cluster_arrays(
-            arrays, field_metadata, missing_element_padding
-        )
+
+        self.combine_cluster_arrays(res, starts, field_metadata)
+        # TODO: Fix type here?
 
         return res
 
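The point of precomputing `starts` is that `read_cluster_range` can now allocate one buffer up front and hand each cluster a slice of it; NumPy basic slices are views, so whatever `read_pages` writes lands directly in `res`, with no per-cluster arrays to concatenate afterwards. A minimal sketch of the pattern, with a hypothetical `fill_cluster` standing in for `read_pages`:

```python
import numpy

def fill_cluster(destination, value):
    # Stand-in for read_pages: writes into whatever buffer it is handed.
    destination[:] = value

total_length, starts = 13, [1, 4, 9]
res = numpy.empty(total_length, dtype=numpy.int64)
res[: starts[0]] = 0  # initialize leading padding / the offsets' extra zero

for i, value in enumerate([10, 20, 30]):
    stop = starts[i + 1] if i != len(starts) - 1 else total_length
    fill_cluster(res[starts[i] : stop], value)  # slice is a view, not a copy

print(res)  # [ 0 10 10 10 20 20 20 20 20 30 30 30 30]
```
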
@@ -698,17 +749,17 @@ def read_pages(
         cluster_idx,
         col_idx,
         field_metadata,
+        destination=None,
         array_cache=None,
     ):
         """
         Args:
+            destination (numpy.ndarray): The array to fill.
             cluster_idx (int): The cluster index.
             col_idx (int): The column index.
             field_metadata (:doc:`uproot.models.RNTuple.FieldClusterMetadata`):
                 The metadata needed to read the field's pages.
             array_cache (None or MutableMapping): Cache of arrays. If None, do not use a cache.
-
-        Returns a numpy array with the data from the column.
         """
         linklist = self._ntuple.page_link_list[cluster_idx]
         # Check if the column is suppressed and pick the non-suppressed one if so
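
The new `destination=None` parameter follows the common optional-output-buffer convention: allocate and return a fresh array when the caller passes nothing, otherwise fill the caller's buffer in place and return nothing. A generic sketch of that convention (hypothetical `compute`, not the uproot signature):

```python
import numpy

def compute(n, destination=None):
    # Allocate only if the caller did not supply a buffer.
    return_buffer = destination is None
    if return_buffer:
        destination = numpy.empty(n, dtype=numpy.float64)
    else:
        assert len(destination) == n, "caller's buffer has the wrong length"
    destination[:] = numpy.arange(n)  # the actual work, done in place
    if return_buffer:
        return destination

out = numpy.empty(4)
compute(4, destination=out)  # fills `out`, returns None
print(out)                   # [0. 1. 2. 3.]
print(compute(4))            # standalone call still returns an array
```
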
@@ -724,28 +775,38 @@ def read_pages(
             else []
         )
         total_len = numpy.sum([desc.num_elements for desc in pagelist], dtype=int)
-        res = numpy.empty(total_len, field_metadata.dtype)
+        if destination is None:
+            return_buffer = True
+            destination = numpy.empty(total_len, dtype=field_metadata.dtype)
+        else:
+            return_buffer = False
+            assert len(destination) == total_len
 
         tracker = 0
         cumsum = 0
         for page_idx, page_desc in enumerate(pagelist):
             n_elements = page_desc.num_elements
             tracker_end = tracker + n_elements
             self.read_page(
-                res[tracker:tracker_end],
+                destination[tracker:tracker_end],
                 cluster_idx,
                 col_idx,
                 page_idx,
                 field_metadata,
                 array_cache=array_cache,
             )
+            if field_metadata.dtype != field_metadata.dtype_result:
+                destination[tracker:tracker_end] = destination[
+                    tracker:tracker_end
+                ].view(field_metadata.dtype)[: tracker_end - tracker]
             if field_metadata.delta:
-                res[tracker] -= cumsum
-                cumsum += numpy.sum(res[tracker:tracker_end])
+                destination[tracker] -= cumsum
+                cumsum += numpy.sum(destination[tracker:tracker_end])
             tracker = tracker_end
 
-        res = self.post_process(res, field_metadata)
-        return res
+        self.post_process(destination, field_metadata)
+        if return_buffer:
+            return destination
 
     def gpu_read_clusters(self, fields, start_cluster_idx, stop_cluster_idx):
         """
@@ -964,26 +1025,25 @@ def post_process(self, buffer, field_metadata):
             field_metadata (:doc:`uproot.models.RNTuple.FieldClusterMetadata`):
                 The metadata needed to post_process buffer.
 
-        Returns post-processed buffer.
+        Performs some post-processing on the buffer in place.
         """
         array_library_string = uproot._util.get_array_library(buffer)
         library = numpy if array_library_string == "numpy" else uproot.extras.cupy()
         if field_metadata.zigzag:
-            buffer = _from_zigzag(buffer)
+            buffer[:] = _from_zigzag(buffer)
         elif field_metadata.delta:
-            buffer = library.cumsum(buffer)
+            buffer[:] = library.cumsum(buffer)
         elif field_metadata.dtype_str == "real32trunc":
-            buffer = buffer.view(library.float32)
+            buffer.dtype = library.float32
         elif field_metadata.dtype_str == "real32quant" and field_metadata.ncol < len(
             self.column_records
         ):
             min_value = self.column_records[field_metadata.ncol].min_value
             max_value = self.column_records[field_metadata.ncol].max_value
-            buffer = min_value + buffer.astype(library.float32) * (
+            buffer.dtype = library.float32
+            buffer[:] = min_value + buffer.view(library.uint32) * (
                 max_value - min_value
             ) / ((1 << field_metadata.nbits) - 1)
-            buffer = buffer.astype(library.float32)
-        return buffer
 
     def deserialize_page_decompressed_buffer(self, destination, field_metadata):
         """
@@ -1081,6 +1141,28 @@ def get_field_metadata(self, ncol):
             dtype_toread = numpy.dtype("uint8")
         else:
             dtype_toread = dtype
+
+        rel_crs = self._column_records_dict[self.column_records[ncol].field_id]
+        alt_dtype_list = []
+        for cr in rel_crs:
+            alt_dtype_byte = self.column_records[cr.idx].type
+            alt_dtype_str = uproot.const.rntuple_col_num_to_dtype_dict[alt_dtype_byte]
+            if alt_dtype_str == "switch":
+                alt_dtype = numpy.dtype([("index", "int64"), ("tag", "int32")])
+            elif alt_dtype_str == "bit":
+                alt_dtype = numpy.dtype("bool")
+            elif alt_dtype_byte in uproot.const.rntuple_custom_float_types:
+                alt_dtype = numpy.dtype("uint32")  # for easier bit manipulation
+            else:
+                alt_dtype = numpy.dtype(alt_dtype_str)
+            alt_dtype_list.append(alt_dtype)
+        # We want to skip doing this for strings.
+        if self.field_records[self.column_records[ncol].field_id].type_name.startswith(
+            "std::string"
+        ):
+            dtype_result = dtype
+        else:
+            dtype_result = numpy.result_type(*alt_dtype_list)
         field_metadata = FieldClusterMetadata(
             ncol,
             dtype_byte,
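
`dtype_result` must be wide enough to hold whichever alternative column representation actually gets read (alternative columns of the same field can be stored with different types), and `numpy.result_type` computes exactly that promotion. For example:

```python
import numpy

# int32 and int64 alternatives promote to int64.
print(numpy.result_type(numpy.dtype("int32"), numpy.dtype("int64")))  # int64

# A signed/unsigned mix promotes to the next wider signed type.
print(numpy.result_type(numpy.dtype("int32"), numpy.dtype("uint32")))  # int64
```
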
@@ -1092,46 +1174,28 @@ def get_field_metadata(self, ncol):
             delta,
             isbit,
             nbits,
+            dtype_result,
         )
         return field_metadata
 
-    def combine_cluster_arrays(self, arrays, field_metadata, missing_element_padding):
+    def combine_cluster_arrays(self, array, starts, field_metadata):
         """
         Args:
-            arrays (list): A list of arrays to combine.
+            array (numpy.ndarray): An array with the full data.
+            starts (list): An array with the start indices of each cluster.
             field_metadata (:doc:`uproot.models.RNTuple.FieldClusterMetadata`):
                 The metadata needed to combine arrays.
-            missing_element_padding (int): Number of padding elements to add at the start of the array.
 
         Returns a field's page arrays concatenated together.
         """
-        array_library_string = uproot._util.get_array_library(arrays[0])
-        library = numpy if array_library_string == "numpy" else uproot.extras.cupy()
-
         # Check if column stores offset values
         if field_metadata.dtype_byte in uproot.const.rntuple_index_types:
-            # Extract the last offset values:
-            last_elements = [
-                (arr[-1] if len(arr) > 0 else library.zeros((), dtype=arr.dtype))
-                for arr in arrays[:-1]
-            ]  # First value always zero, therefore skip first arr.
-            last_offsets = library.cumsum(library.array(last_elements))
-            for i in range(1, len(arrays)):
-                arrays[i] += last_offsets[i - 1]
-
-        res = library.concatenate(arrays, axis=0)
-
-        # No longer needed; free memory
-        del arrays
-
-        if field_metadata.dtype_byte in uproot.const.rntuple_index_types:
-            # for offsets
-            res = numpy.insert(res, 0, 0) if library == numpy else _cupy_insert0(res)
-
-        if missing_element_padding:
-            res = numpy.pad(res, (missing_element_padding, 0))
-
-        return res
+            for i in range(1, len(starts)):
+                start = starts[i]
+                stop = starts[i + 1] if i != len(starts) - 1 else len(array)
+                if start == stop:
+                    continue
+                array[start:stop] += array[start - 1]
 
 
 def _extract_bits(packed, nbits):
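
The rewritten `combine_cluster_arrays` replaces the old concatenate-and-shift approach with a single in-place pass: after `read_pages`, each cluster's offsets are cumulative only within that cluster, so adding the previous cluster's last global offset (`array[start - 1]`, read as a scalar, so there is no aliasing hazard) to the next cluster's slice globalizes them. A worked sketch with two made-up clusters:

```python
import numpy

# Offsets as read_cluster_range leaves them: a leading zero, then each
# cluster's locally cumulative offsets.
array = numpy.array([0, 2, 5, 6, 3, 4, 8], dtype=numpy.int64)
starts = [1, 4]  # cluster 0 begins at index 1, cluster 1 at index 4

for i in range(1, len(starts)):
    start = starts[i]
    stop = starts[i + 1] if i != len(starts) - 1 else len(array)
    if start == stop:  # skip empty clusters
        continue
    array[start:stop] += array[start - 1]  # shift by previous global offset

print(array)  # [ 0  2  5  6  9 10 14]
```
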
@@ -1772,6 +1836,7 @@ class FieldClusterMetadata:
     delta: bool
     isbit: bool
     nbits: int
+    dtype_result: numpy.dtype
 
 
 @dataclasses.dataclass
