
Refactor serialize.py #9579

Merged
merged 1 commit, Mar 25, 2025
158 changes: 101 additions & 57 deletions extension/flat_tensor/serialize/serialize.py
@@ -10,29 +10,33 @@
import os
import tempfile
from dataclasses import dataclass
from typing import ClassVar, Dict, List, Literal, Optional
from typing import ClassVar, Dict, List, Literal, Optional, Sequence

import pkg_resources
from executorch.exir._serialize._cord import Cord
from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass

from executorch.exir._serialize._flatbuffer import _flatc_compile, _flatc_decompile
from executorch.exir._serialize._program import _insert_flatbuffer_header
from executorch.exir._serialize.data_serializer import DataPayload, DataSerializer
from executorch.exir._serialize.data_serializer import (
DataPayload,
DataSerializer,
TensorEntry,
)

from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required

# Byte order of numbers written to flat tensor headers. Always little-endian
# regardless of the host system, since all commonly-used modern CPUs are little
# endian.
_HEADER_BYTEORDER: Literal["little"] = "little"

from executorch.extension.flat_tensor.serialize.flat_tensor_schema import (
DataSegment,
FlatTensor,
TensorMetadata,
)

# Byte order of numbers written to flat tensor headers. Always little-endian
# regardless of the host system, since all commonly-used modern CPUs are little
# endian.
_HEADER_BYTEORDER: Literal["little"] = "little"
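
As a quick illustration of what the fixed byte order means in practice: with "little", Python's int.to_bytes puts the low byte first regardless of the host CPU. The value below is made up, not taken from the format:

    value = 0x12345678
    # Low byte first, independent of the host's native byte order.
    assert value.to_bytes(4, byteorder="little") == b"\x78\x56\x34\x12"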


def _serialize_to_flatbuffer(flat_tensor: FlatTensor) -> Cord:
"""Serializes a FlatTensor to a flatbuffer and returns the serialized data."""
@@ -209,6 +213,62 @@ def _get_extended_header(flat_tensor_data: bytes) -> Optional[FlatTensorHeader]:
return None


def _extract_tensors(
fqn_to_tensor: Dict[str, TensorEntry],
buffers: Sequence[bytes],
segments: List[Cord],
tensor_alignment: int,
) -> List[TensorMetadata]:
"""Places tensors into a single segment, aligned to tensor_alignment within
the segment.

Args:
fqn_to_tensor: A map from fully qualified names to tensor entries.
buffers: A sequence of tensor buffers.
segments: A list of segments to append the tensor data to. Modified in-place.
tensor_alignment: The alignment of the tensor data.

Returns:
A list of TensorMetadata, which describes the tensors in the segment.
"""
tensor_data: Cord = Cord()
tensors: List[TensorMetadata] = []
# Map from buffer index to the buffer's offset within tensor_data.
saved_offsets: Dict[int, int] = {}
for fqn, tensor_entry in fqn_to_tensor.items():
assert tensor_entry.layout is not None
# Check that the index into the tensor buffers is valid.
assert tensor_entry.buffer_index < len(
buffers
), f"Invalid index {tensor_entry.buffer_index} is greater than tensor buffer size {len(buffers)}."

# Check if the buffer has already been appended to tensor_data.
offset = saved_offsets.get(tensor_entry.buffer_index, -1)
if offset == -1:
if len(tensor_data) > 0:
# Pad out the previous tensor so this one starts on an alignment boundary.
pad_length = padding_required(len(tensor_data), tensor_alignment)
tensor_data.append(b"\x00" * pad_length)
# Add to saved offsets.
offset = len(tensor_data)
saved_offsets[tensor_entry.buffer_index] = offset
# Append the buffer to tensor_data at the saved offset.
tensor_data.append(buffers[tensor_entry.buffer_index])

tensors.append(
TensorMetadata(
fully_qualified_name=fqn,
scalar_type=tensor_entry.layout.scalar_type,
sizes=tensor_entry.layout.sizes,
dim_order=tensor_entry.layout.dim_order,
segment_index=len(segments),
offset=offset,
)
)
segments.append(tensor_data)
return tensors


class FlatTensorSerializer(DataSerializer):
"""A concrete implementation of the DataSerializer interface that
serializes and deserializes data to/from the FlatTensor format.
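
A note on the extracted helper: the loop deduplicates buffers that several fully qualified names share, and aligns each distinct buffer within the segment. A toy re-implementation over plain bytes (hypothetical names, not the ExecuTorch API) shows the intended behavior:

    def pack_buffers(indices, buffers, alignment):
        # Toy version of the loop in _extract_tensors: a shared buffer index
        # resolves to one stored copy, and each distinct buffer starts on an
        # alignment boundary within the segment.
        data = bytearray()
        offsets = {}
        resolved = []
        for idx in indices:
            if idx not in offsets:
                data.extend(b"\x00" * (-len(data) % alignment))
                offsets[idx] = len(data)
                data.extend(buffers[idx])
            resolved.append(offsets[idx])
        return bytes(data), resolved

    # Two entries share buffer 0; buffer 1 starts at the next 16-byte boundary.
    data, resolved = pack_buffers([0, 0, 1], [b"abc", b"defgh"], 16)
    assert resolved == [0, 0, 16] and len(data) == 21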
@@ -227,61 +287,45 @@ def serialize(
self,
data: DataPayload,
) -> Cord:
"""Serializes a list of tensor metadata and tensors into a blob."""

flat_tensor_metadata: List[TensorMetadata] = []
flat_tensor_data: Cord = Cord()

# {idx, offset}
saved_offsets: Dict[int, int] = {}

for fqn, tensor_entry in data.fqn_to_tensor.items():
assert tensor_entry.layout is not None
# Check index into the tensor buffers is valid.
assert tensor_entry.buffer_index < len(
data.buffers
), f"Invalid index {tensor_entry.buffer_index} is greater than tensor buffer size {len(data.buffers)}."

# Check if the tensor has already been appended to the flat_tensor_data.
offset = saved_offsets.get(tensor_entry.buffer_index, -1)
if offset == -1:
if len(flat_tensor_data) > 0:
# Add padding to round off the previous tensor offset.
pad_length = padding_required(
len(flat_tensor_data), self.config.tensor_alignment
)
flat_tensor_data.append(b"\x00" * pad_length)
# Add to saved offsets.
offset = len(flat_tensor_data)
saved_offsets[tensor_entry.buffer_index] = offset
# Append to flat_tensor_data at the offset.
flat_tensor_data.append(data.buffers[tensor_entry.buffer_index])

flat_tensor_metadata.append(
TensorMetadata(
fully_qualified_name=fqn,
scalar_type=tensor_entry.layout.scalar_type,
sizes=tensor_entry.layout.sizes,
dim_order=tensor_entry.layout.dim_order,
segment_index=0,
offset=offset,
"""Serializes a list of tensors and named data into a blob."""

segments: List[Cord] = []
tensors = _extract_tensors(
data.fqn_to_tensor,
data.buffers,
segments,
self.config.tensor_alignment,
)

data_segments: List[DataSegment] = []
segment_data = Cord()
for segment in segments:
prev_end = (
(data_segments[-1].offset + data_segments[-1].size)
if data_segments
else 0
)
data_segments.append(
DataSegment(
offset=aligned_size(prev_end, self.config.segment_alignment),
size=len(segment),
)
)

# Pad flat_tensor_data to segment alignment.
segment_pad_length = padding_required(
len(flat_tensor_data), self.config.segment_alignment
)
if segment_pad_length > 0:
flat_tensor_data.append(b"\x00" * segment_pad_length)
# Pad segment_data to segment alignment.
segment_pad_length = padding_required(
len(segment_data), self.config.segment_alignment
)
if segment_pad_length > 0:
segment_data.append(b"\x00" * segment_pad_length)
segment_data.append(segment)

# Create FlatTensor, which describes the contents of the file and
# points to all the data segments. It will be serialized to a flatbuffer.
flat_tensor = FlatTensor(
version=0, # Keep in sync with c++ version number in serialize.h
tensor_alignment=self.config.tensor_alignment,
tensors=flat_tensor_metadata,
segments=[DataSegment(offset=0, size=len(flat_tensor_data))],
tensors=tensors,
segments=data_segments,
named_data=[],
)

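The segment-offset bookkeeping above hinges on aligned_size rounding the previous segment's end up to the segment alignment. A small standalone sketch; the helper is written out here rather than imported, and the 4096-byte alignment and segment sizes are illustrative:

    def round_up(n, alignment):
        # Same arithmetic as aligned_size: round n up to a multiple of alignment.
        return (n + alignment - 1) // alignment * alignment

    offsets, prev_end = [], 0
    for size in [100, 50]:
        offset = round_up(prev_end, 4096)
        offsets.append((offset, size))
        prev_end = offset + size
    # The second segment starts at the aligned end of the first.
    assert offsets == [(0, 100), (4096, 50)]
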
@@ -307,7 +351,7 @@ def serialize(
flatbuffer_offset=padded_header_length,
flatbuffer_size=len(flatbuffer_payload),
segment_base_offset=segment_base_offset,
segment_data_size=len(flat_tensor_data),
segment_data_size=len(segment_data),
).to_bytes()

# Pad header and payload to segment alignment.
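Since the segment data is the last region the header describes, the header fields alone should determine the total file size, assuming the padded header and flatbuffer region end exactly at segment_base_offset (that padding happens in the elided lines above). A hedged sketch:

    def expected_file_size(header):
        # Assumes the header + flatbuffer region is padded up to
        # segment_base_offset, so the segment data is the final region.
        return header.segment_base_offset + header.segment_data_size
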
@@ -327,15 +371,15 @@ assert eh.flatbuffer_size == original_flatbuffer_payload_size
assert eh.flatbuffer_size == original_flatbuffer_payload_size
assert eh.segment_base_offset == segment_base_offset
assert eh.flatbuffer_offset == padded_header_length
assert eh.segment_data_size == len(flat_tensor_data)
assert eh.segment_data_size == len(segment_data)

del header_data
del flatbuffer_payload

# Place everything into one segment.
payload = Cord()
payload.append(injected_flatbuffer_data)
payload.append(flat_tensor_data)
payload.append(segment_data)

return payload

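With this layout, a reader resolving one tensor combines three offsets: the segment base from the header, the segment's offset within the segment region, and the tensor's offset within its segment. A hypothetical reader-side sketch (names assumed, not the ExecuTorch runtime API):

    def tensor_file_offset(segment_base_offset, segment, tensor):
        # segment.offset comes from DataSegment; tensor.offset comes from
        # TensorMetadata, relative to the start of its segment.
        return segment_base_offset + segment.offset + tensor.offset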