diff --git a/extension/flat_tensor/serialize/serialize.py b/extension/flat_tensor/serialize/serialize.py
index 683530adbfd..3428fe49117 100644
--- a/extension/flat_tensor/serialize/serialize.py
+++ b/extension/flat_tensor/serialize/serialize.py
@@ -10,7 +10,7 @@
 import os
 import tempfile
 from dataclasses import dataclass
-from typing import ClassVar, Dict, List, Literal, Optional
+from typing import ClassVar, Dict, List, Literal, Optional, Sequence

 import pkg_resources
 from executorch.exir._serialize._cord import Cord
@@ -18,21 +18,25 @@
 from executorch.exir._serialize._flatbuffer import _flatc_compile, _flatc_decompile
 from executorch.exir._serialize._program import _insert_flatbuffer_header

-from executorch.exir._serialize.data_serializer import DataPayload, DataSerializer
+from executorch.exir._serialize.data_serializer import (
+    DataPayload,
+    DataSerializer,
+    TensorEntry,
+)

 from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required

-# Byte order of numbers written to flat tensor headers. Always little-endian
-# regardless of the host system, since all commonly-used modern CPUs are little
-# endian.
-_HEADER_BYTEORDER: Literal["little"] = "little"
-
 from executorch.extension.flat_tensor.serialize.flat_tensor_schema import (
     DataSegment,
     FlatTensor,
     TensorMetadata,
 )

+# Byte order of numbers written to flat tensor headers. Always little-endian
+# regardless of the host system, since all commonly-used modern CPUs are little
+# endian.
+_HEADER_BYTEORDER: Literal["little"] = "little"
+

 def _serialize_to_flatbuffer(flat_tensor: FlatTensor) -> Cord:
     """Serializes a FlatTensor to a flatbuffer and returns the serialized data."""
@@ -209,6 +213,62 @@ def _get_extended_header(flat_tensor_data: bytes) -> Optional[FlatTensorHeader]:
     return None


+def _extract_tensors(
+    fqn_to_tensor: Dict[str, TensorEntry],
+    buffers: Sequence[bytes],
+    segments: List[Cord],
+    tensor_alignment: int,
+) -> List[TensorMetadata]:
+    """Places tensors into a single segment, aligned to tensor_alignment within
+    the segment.
+
+    Args:
+        fqn_to_tensor: A map from fully qualified names to tensor entries.
+        buffers: A sequence of tensor buffers.
+        segments: A list of segments to append the tensor data to. Modified in-place.
+        tensor_alignment: The alignment of the tensor data.
+
+    Returns:
+        A list of TensorMetadata, which describes the tensors in the segment.
+    """
+    tensor_data: Cord = Cord()
+    tensors: List[TensorMetadata] = []
+    # {idx, offset}
+    saved_offsets: Dict[int, int] = {}
+    for fqn, tensor_entry in fqn_to_tensor.items():
+        assert tensor_entry.layout is not None
+        # Check index into the tensor buffers is valid.
+        assert tensor_entry.buffer_index < len(
+            buffers
+        ), f"Invalid index {tensor_entry.buffer_index} is greater than tensor buffer size {len(buffers)}."
+
+        # Check if the tensor has already been appended to the flat_tensor_data.
+        offset = saved_offsets.get(tensor_entry.buffer_index, -1)
+        if offset == -1:
+            if len(tensor_data) > 0:
+                # Add padding to round off the previous tensor offset.
+                pad_length = padding_required(len(tensor_data), tensor_alignment)
+                tensor_data.append(b"\x00" * pad_length)
+            # Add to saved offsets.
+            offset = len(tensor_data)
+            saved_offsets[tensor_entry.buffer_index] = offset
+            # Append to flat_tensor_data at the offset.
+            tensor_data.append(buffers[tensor_entry.buffer_index])
+
+        tensors.append(
+            TensorMetadata(
+                fully_qualified_name=fqn,
+                scalar_type=tensor_entry.layout.scalar_type,
+                sizes=tensor_entry.layout.sizes,
+                dim_order=tensor_entry.layout.dim_order,
+                segment_index=len(segments),
+                offset=offset,
+            )
+        )
+    segments.append(tensor_data)
+    return tensors
+
+
 class FlatTensorSerializer(DataSerializer):
     """A concrete implementation of the DataSerializer interface that
     serializes and deserializes data to/from the FlatTensor format.
@@ -227,61 +287,45 @@ def serialize(
         self,
         data: DataPayload,
     ) -> Cord:
-        """Serializes a list of tensor metadata and tensors into a blob."""
-
-        flat_tensor_metadata: List[TensorMetadata] = []
-        flat_tensor_data: Cord = Cord()
-
-        # {idx, offset}
-        saved_offsets: Dict[int, int] = {}
-
-        for fqn, tensor_entry in data.fqn_to_tensor.items():
-            assert tensor_entry.layout is not None
-            # Check index into the tensor buffers is valid.
-            assert tensor_entry.buffer_index < len(
-                data.buffers
-            ), f"Invalid index {tensor_entry.buffer_index} is greater than tensor buffer size {len(data.buffers)}."
-
-            # Check if the tensor has already been appended to the flat_tensor_data.
-            offset = saved_offsets.get(tensor_entry.buffer_index, -1)
-            if offset == -1:
-                if len(flat_tensor_data) > 0:
-                    # Add padding to round off the previous tensor offset.
-                    pad_length = padding_required(
-                        len(flat_tensor_data), self.config.tensor_alignment
-                    )
-                    flat_tensor_data.append(b"\x00" * pad_length)
-                # Add to saved offsets.
-                offset = len(flat_tensor_data)
-                saved_offsets[tensor_entry.buffer_index] = offset
-                # Append to flat_tensor_data at the offset.
-                flat_tensor_data.append(data.buffers[tensor_entry.buffer_index])
-
-            flat_tensor_metadata.append(
-                TensorMetadata(
-                    fully_qualified_name=fqn,
-                    scalar_type=tensor_entry.layout.scalar_type,
-                    sizes=tensor_entry.layout.sizes,
-                    dim_order=tensor_entry.layout.dim_order,
-                    segment_index=0,
-                    offset=offset,
+        """Serializes a list of tensors and named data into a blob."""
+
+        segments: List[Cord] = []
+        tensors = _extract_tensors(
+            data.fqn_to_tensor,
+            data.buffers,
+            segments,
+            self.config.tensor_alignment,
+        )
+
+        data_segments: List[DataSegment] = []
+        segment_data = Cord()
+        for segment in segments:
+            prev_end = (
+                (data_segments[-1].offset + data_segments[-1].size)
+                if data_segments
+                else 0
+            )
+            data_segments.append(
+                DataSegment(
+                    offset=aligned_size(prev_end, self.config.segment_alignment),
+                    size=len(segment),
                 )
             )
-
-        # Pad flat_tensor_data to segment alignment.
-        segment_pad_length = padding_required(
-            len(flat_tensor_data), self.config.segment_alignment
-        )
-        if segment_pad_length > 0:
-            flat_tensor_data.append(b"\x00" * segment_pad_length)
+            # Pad segment_data to segment alignment.
+            segment_pad_length = padding_required(
+                len(segment_data), self.config.segment_alignment
+            )
+            if segment_pad_length > 0:
+                segment_data.append(b"\x00" * segment_pad_length)
+            segment_data.append(segment)

         # Create FlatTensor, which describes of the contents of the file and
         # points to all the data segments. It will be serialized to flatbuffer.
         flat_tensor = FlatTensor(
             version=0,  # Keep in sync with c++ version number in serialize.h
             tensor_alignment=self.config.tensor_alignment,
-            tensors=flat_tensor_metadata,
-            segments=[DataSegment(offset=0, size=len(flat_tensor_data))],
+            tensors=tensors,
+            segments=data_segments,
             named_data=[],
         )

@@ -307,7 +351,7 @@ def serialize(
             flatbuffer_offset=padded_header_length,
             flatbuffer_size=len(flatbuffer_payload),
             segment_base_offset=segment_base_offset,
-            segment_data_size=len(flat_tensor_data),
+            segment_data_size=len(segment_data),
         ).to_bytes()

         # Pad header and payload to segment alignment.
@@ -327,7 +371,7 @@ def serialize(
         assert eh.flatbuffer_size == original_flatbuffer_payload_size
         assert eh.segment_base_offset == segment_base_offset
         assert eh.flatbuffer_offset == padded_header_length
-        assert eh.segment_data_size == len(flat_tensor_data)
+        assert eh.segment_data_size == len(segment_data)

         del header_data
         del flatbuffer_payload
@@ -335,7 +379,7 @@ def serialize(
         # Place everything into one segment.
         payload = Cord()
         payload.append(injected_flatbuffer_data)
-        payload.append(flat_tensor_data)
+        payload.append(segment_data)

         return payload
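
A minimal standalone sketch of the layout bookkeeping this change introduces, using plain Python with no executorch imports. padding_required and aligned_size are re-implemented here under the assumption that they round up to a multiple of the alignment, matching how the diff uses them (verify against executorch.exir._serialize.padding), and pack_tensors is a hypothetical stand-in for _extract_tensors: each buffer is stored once per buffer_index, padded to tensor_alignment inside a segment, and each segment's file offset is the previous end rounded up to segment_alignment.

# Standalone sketch of the alignment/offset bookkeeping in this diff.
# padding_required/aligned_size are assumed to round up to a multiple of
# `alignment`; pack_tensors is a hypothetical stand-in for _extract_tensors.

def padding_required(offset: int, alignment: int) -> int:
    """Bytes of zero padding needed to bring `offset` up to `alignment`."""
    return (alignment - offset % alignment) % alignment


def aligned_size(size: int, alignment: int) -> int:
    """`size` rounded up to the next multiple of `alignment`."""
    return size + padding_required(size, alignment)


def pack_tensors(buffers, entries, tensor_alignment):
    """entries: (fqn, buffer_index) pairs. Returns (segment_bytes, {fqn: offset}).

    Each buffer is stored once, even when several fqns share a buffer_index,
    and every stored buffer starts at a multiple of tensor_alignment.
    """
    data = bytearray()
    stored = {}  # buffer_index -> offset within the segment
    offsets = {}
    for fqn, idx in entries:
        if idx not in stored:
            data.extend(b"\x00" * padding_required(len(data), tensor_alignment))
            stored[idx] = len(data)
            data.extend(buffers[idx])
        offsets[fqn] = stored[idx]
    return bytes(data), offsets


segment, offsets = pack_tensors(
    buffers=[b"\x01\x02\x03", b"\xff" * 5],
    entries=[("a.weight", 0), ("a.bias", 1), ("a.weight_tied", 0)],
    tensor_alignment=16,
)
print(offsets)  # {'a.weight': 0, 'a.bias': 16, 'a.weight_tied': 0}
print(len(segment))  # 21
# File offset of the next segment, as serialize() now computes it:
print(aligned_size(len(segment), 128))  # 128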