From ad715d9eb1d6da026664c452d9f2654f0805ea69 Mon Sep 17 00:00:00 2001 From: lucylq Date: Mon, 21 Apr 2025 12:01:37 -0700 Subject: [PATCH] Update flat tensor ndm to account for named delegate data Currently flat_tensor ndm only accounts for tensors in get_data, get_num_keys, get_key functions. Add support to return named_data values as well. TODO: consolidate tensors and named_data into one structure in the flatbuffer. This will simplify all the serialization and runtime code. Currently, we assume that a PTD file has either tensors or named_data, not both. After the consolidation, this won't be an issue. Differential Revision: [D73380805](https://our.internmc.facebook.com/intern/diff/D73380805/) [ghstack-poisoned] --- .../flat_tensor/flat_tensor_data_map.cpp | 70 +++++++++++++- .../test/flat_tensor_data_map_test.cpp | 91 ++++++++++++++++--- extension/flat_tensor/test/targets.bzl | 2 +- test/models/export_delegated_program.py | 5 + test/models/targets.bzl | 1 + 5 files changed, 151 insertions(+), 18 deletions(-) diff --git a/extension/flat_tensor/flat_tensor_data_map.cpp b/extension/flat_tensor/flat_tensor_data_map.cpp index 8aa0af13928..2690ea8a6a4 100644 --- a/extension/flat_tensor/flat_tensor_data_map.cpp +++ b/extension/flat_tensor/flat_tensor_data_map.cpp @@ -65,6 +65,25 @@ Result get_flat_tensor_metadata( return Error::NotFound; } +Result get_named_data( + const char* key, + const flatbuffers::Vector< + flatbuffers::Offset>* named_data) { + // Linear search by name. + for (int i = 0; i < named_data->size(); i++) { + if (std::strcmp(named_data->Get(i)->key()->c_str(), key) == 0) { + const auto* metadata = named_data->Get(i); + ET_CHECK_OR_RETURN_ERROR( + metadata->segment_index() >= 0, + InvalidExternalData, + "Invalid segment_index %d; malformed PTD file.", + metadata->segment_index()); + return metadata; + } + } + return Error::NotFound; +} + Result create_tensor_layout( const flat_tensor_flatbuffer::TensorMetadata* tensor_metadata) { ScalarType scalar_type = @@ -109,6 +128,39 @@ ET_NODISCARD Result FlatTensorDataMap::get_metadata( ET_NODISCARD Result FlatTensorDataMap::get_data( const char* key) const { + // TODO(lfq): consolidate named_data and tensors. + // Check named data. + Result named_data = + get_named_data(key, flat_tensor_->named_data()); + if (named_data.ok()) { + size_t segment_index = named_data.get()->segment_index(); + ET_CHECK_OR_RETURN_ERROR( + segment_index < flat_tensor_->segments()->size(), + InvalidExternalData, + "Invalid segment_index %zu; malformed PTD file.", + segment_index); + + size_t segment_offset = + flat_tensor_->segments()->Get(segment_index)->offset(); + size_t segment_size = flat_tensor_->segments()->Get(segment_index)->size(); + ET_CHECK_OR_RETURN_ERROR( + segment_offset < + header_.segment_base_offset + header_.segment_data_size, + InvalidExternalData, + "Invalid segment offset %zu is larger than the segment_base_offset + segment_data_size %" PRIu64 + "; malformed PTD file.", + segment_offset, + header_.segment_base_offset + header_.segment_data_size); + return loader_->load( + /*offset=*/header_.segment_base_offset + segment_offset, + segment_size, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External)); + } + if (named_data.error() != Error::NotFound) { + return named_data.error(); + } + + // Check tensors, if named data is not found. Result metadata = get_flat_tensor_metadata(key, flat_tensor_->tensors()); if (!metadata.ok()) { @@ -179,16 +231,28 @@ ET_NODISCARD Error FlatTensorDataMap::load_data_into( } ET_NODISCARD Result FlatTensorDataMap::get_num_keys() const { - return flat_tensor_->tensors()->size(); + // TODO(lfq): consolidate named_data and tensors. + return flat_tensor_->tensors()->size() + flat_tensor_->named_data()->size(); } ET_NODISCARD Result FlatTensorDataMap::get_key( size_t index) const { - if (index < 0 || index >= flat_tensor_->tensors()->size()) { + // TODO(lfq): consolidate named_data and tensors. + // Currently, this assumes we either have tensors or named_data, but not both. + if (flat_tensor_->tensors()->size() > 0 && flat_tensor_->named_data()->size() > 0) { + return Error::NotImplemented; + } + if (index < 0) { return Error::InvalidArgument; } + if (index < flat_tensor_->tensors()->size()) { + return flat_tensor_->tensors()->Get(index)->fully_qualified_name()->c_str(); + } + if (index < flat_tensor_->named_data()->size()) { + return flat_tensor_->named_data()->Get(index)->key()->c_str(); + } - return flat_tensor_->tensors()->Get(index)->fully_qualified_name()->c_str(); + return Error::InvalidArgument; } /* static */ Result FlatTensorDataMap::load( diff --git a/extension/flat_tensor/test/flat_tensor_data_map_test.cpp b/extension/flat_tensor/test/flat_tensor_data_map_test.cpp index ac4583eda88..855c74ba86b 100644 --- a/extension/flat_tensor/test/flat_tensor_data_map_test.cpp +++ b/extension/flat_tensor/test/flat_tensor_data_map_test.cpp @@ -28,32 +28,36 @@ using torch::executor::util::FileDataLoader; class FlatTensorDataMapTest : public ::testing::Test { protected: + void create_loader(const char* path, const char* module_name) { + // Create a loader for the serialized data map. + Result loader = FileDataLoader::from(path); + ASSERT_EQ(loader.error(), Error::Ok); + loaders_.insert( + {module_name, + std::make_unique(std::move(loader.get()))}); + } void SetUp() override { // Since these tests cause ET_LOG to be called, the PAL must be initialized // first. executorch::runtime::runtime_init(); - // Load data map. The eager linear model is defined at: - // //executorch/test/models/linear_model.py - const char* path = std::getenv("ET_MODULE_LINEAR_DATA_PATH"); - Result loader = FileDataLoader::from(path); - ASSERT_EQ(loader.error(), Error::Ok); - - data_map_loader_ = - std::make_unique(std::move(loader.get())); + // Model defined in //executorch/test/models/linear_model.py + create_loader(std::getenv("ET_MODULE_LINEAR_DATA_PATH"), "linear"); + // Model defined in //executorch/test/models/export_delegated_program.py + create_loader(std::getenv("ET_MODULE_LINEAR_XNN_DATA_PATH"), "linear_xnn"); } - std::unique_ptr data_map_loader_; + std::unordered_map> loaders_; }; TEST_F(FlatTensorDataMapTest, LoadFlatTensorDataMap) { Result data_map = - FlatTensorDataMap::load(data_map_loader_.get()); + FlatTensorDataMap::load(loaders_["linear"].get()); EXPECT_EQ(data_map.error(), Error::Ok); } TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetMetadata) { Result data_map = - FlatTensorDataMap::load(data_map_loader_.get()); + FlatTensorDataMap::load(loaders_["linear"].get()); EXPECT_EQ(data_map.error(), Error::Ok); // Check tensor layouts are correct. @@ -95,7 +99,7 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetMetadata) { TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetData) { Result data_map = - FlatTensorDataMap::load(data_map_loader_.get()); + FlatTensorDataMap::load(loaders_["linear"].get()); EXPECT_EQ(data_map.error(), Error::Ok); // Check tensor data sizes are correct. @@ -116,7 +120,7 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetData) { TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_Keys) { Result data_map = - FlatTensorDataMap::load(data_map_loader_.get()); + FlatTensorDataMap::load(loaders_["linear"].get()); EXPECT_EQ(data_map.error(), Error::Ok); // Check num tensors is 2. @@ -140,7 +144,7 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_Keys) { TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_LoadInto) { Result data_map = - FlatTensorDataMap::load(data_map_loader_.get()); + FlatTensorDataMap::load(loaders_["linear"].get()); EXPECT_EQ(data_map.error(), Error::Ok); // get the metadata @@ -160,3 +164,62 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_LoadInto) { } free(data); } + +TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetData_Xnnpack) { + Result data_map = + FlatTensorDataMap::load(loaders_["linear_xnn"].get()); + EXPECT_EQ(data_map.error(), Error::Ok); + + // Check tensor data sizes are correct. + // 64eec129c8d3f58ee6b7ca145b25e312fa82d3d276db5adaedb59aaebb824885 is the + // hash of the 3*3 identity matrix + Result data_weight_res = data_map->get_data( + "64eec129c8d3f58ee6b7ca145b25e312fa82d3d276db5adaedb59aaebb824885"); + ASSERT_EQ(Error::Ok, data_weight_res.error()); + FreeableBuffer data_a = std::move(data_weight_res.get()); + EXPECT_EQ(data_a.size(), 36); // 3*3*4 (3*3 matrix, 4 bytes per float) + + // 15ec7bf0b50732b49f8228e07d24365338f9e3ab994b00af08e5a3bffe55fd8b is the + // hash of the 3*1 vector [1, 1, 1] + Result data_bias_res = data_map->get_data( + "15ec7bf0b50732b49f8228e07d24365338f9e3ab994b00af08e5a3bffe55fd8b"); + ASSERT_EQ(Error::Ok, data_bias_res.error()); + FreeableBuffer data_b = std::move(data_bias_res.get()); + EXPECT_EQ(data_b.size(), 12); // 3*4 (3*1 vector, 4 bytes per float) + + // Check get_data fails when key is not found. + Result data_c_res = data_map->get_data("c"); + EXPECT_EQ(data_c_res.error(), Error::NotFound); +} + +TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_Keys_Xnnpack) { + Result data_map = + FlatTensorDataMap::load(loaders_["linear_xnn"].get()); + EXPECT_EQ(data_map.error(), Error::Ok); + + // Check num tensors is 2. + Result num_tensors_res = data_map->get_num_keys(); + ASSERT_EQ(Error::Ok, num_tensors_res.error()); + EXPECT_EQ(num_tensors_res.get(), 2); + + // Check get_key returns the correct keys. + Result key0_res = data_map->get_key(0); + ASSERT_EQ(Error::Ok, key0_res.error()); + EXPECT_EQ( + strcmp( + key0_res.get(), + "64eec129c8d3f58ee6b7ca145b25e312fa82d3d276db5adaedb59aaebb824885"), + 0); + + Result key1_res = data_map->get_key(1); + ASSERT_EQ(Error::Ok, key1_res.error()); + EXPECT_EQ( + strcmp( + key1_res.get(), + "15ec7bf0b50732b49f8228e07d24365338f9e3ab994b00af08e5a3bffe55fd8b"), + 0); + + // Check get_key fails when out of bounds. + Result key2_res = data_map->get_key(2); + EXPECT_EQ(key2_res.error(), Error::InvalidArgument); +} diff --git a/extension/flat_tensor/test/targets.bzl b/extension/flat_tensor/test/targets.bzl index a2b96526ab5..bd81c937f8c 100644 --- a/extension/flat_tensor/test/targets.bzl +++ b/extension/flat_tensor/test/targets.bzl @@ -35,8 +35,8 @@ def define_common_targets(is_fbcode=False): # The tests use this var to find the program file to load. This uses # an fbcode target path because the authoring/export tools # intentionally don't work in xplat (since they're host-only tools). - "ET_MODULE_LINEAR_PROGRAM_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleLinear.pte])", "ET_MODULE_LINEAR_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleLinear.ptd])", + "ET_MODULE_LINEAR_XNN_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_data[ModuleLinear-e.ptd])", } runtime.cxx_test( diff --git a/test/models/export_delegated_program.py b/test/models/export_delegated_program.py index f23d12c2b1d..ff76ea4c80a 100644 --- a/test/models/export_delegated_program.py +++ b/test/models/export_delegated_program.py @@ -99,6 +99,11 @@ class ModuleLinear(torch.nn.Module): def __init__(self): super().__init__() self.linear = torch.nn.Linear(3, 3) + # Make the linear deterministic. + self.linear.weight.data = torch.tensor( + [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] + ) # 3x3 identity matrix + self.linear.bias.data = torch.tensor([0.0, 0.0, 0.0]) def forward(self, x: torch.Tensor): return self.linear(x) diff --git a/test/models/targets.bzl b/test/models/targets.bzl index d620a14f322..c5ff67222a0 100644 --- a/test/models/targets.bzl +++ b/test/models/targets.bzl @@ -222,6 +222,7 @@ def define_common_targets(): default_outs = ["."], visibility = [ "//executorch/runtime/executor/test/...", + "//executorch/extension/flat_tensor/test/...", "//executorch/test/...", ], )