From bb9ec6d119dde2fd57c2fea91f0c6e5bd3273d78 Mon Sep 17 00:00:00 2001 From: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com> Date: Mon, 5 May 2025 07:09:31 -0700 Subject: [PATCH] [ET-VK] Using vector for storing ref_mapping_ in GraphBuilder to improve model load time and memory. Pull Request resolved: https://github.com/pytorch/executorch/pull/10647 This diff changes GraphBuilder class to store ref id to value mapping as vector instead of unordered map, since maximum id is known and thus vector can be sized to store the map. ghstack-source-id: 282013578 @exported-using-ghexport Differential Revision: [D73969916](https://our.internmc.facebook.com/intern/diff/D73969916/) --- backends/vulkan/runtime/VulkanBackend.cpp | 24 +++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index 21fd137b65b..b32f4eb4308 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -152,7 +152,7 @@ class GraphBuilder { VkGraphPtr flatbuffer_; const uint8_t* constant_data_; - std::unordered_map ref_mapping_; + std::vector ref_mapping_; public: explicit GraphBuilder( @@ -164,22 +164,20 @@ class GraphBuilder { constant_data_(constant_data), ref_mapping_() {} - bool fb_id_exists(const uint32_t fb_id) { - const std::unordered_map::iterator found_ref = - ref_mapping_.find(fb_id); + void resize(uint32_t size) { + ref_mapping_.resize(size, INT32_MAX); + } - return found_ref != ref_mapping_.end(); + bool fb_id_exists(const uint32_t fb_id) { + return fb_id < ref_mapping_.size() && ref_mapping_[fb_id] != INT32_MAX; } ValueRef get_fb_id_valueref(const uint32_t fb_id) { - const std::unordered_map::iterator found_ref = - ref_mapping_.find(fb_id); - ET_CHECK_MSG( - found_ref != ref_mapping_.end(), + fb_id_exists(fb_id), "Trying to extract a value that hasn't yet been added to the graph."); - return found_ref->second; + return ref_mapping_[fb_id]; } void add_tensor_to_graph(const uint32_t fb_id, VkTensorPtr tensor_fb) { @@ -315,6 +313,9 @@ class GraphBuilder { } void build_graph() { + // Resize the mapping to the number of values in the flatbuffer + resize(flatbuffer_->values()->size()); + // First, add all values to the graph for (uint32_t fb_id = 0; fb_id < flatbuffer_->values()->size(); ++fb_id) { VkValuePtr value = flatbuffer_->values()->Get(fb_id); @@ -489,8 +490,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { VkGraphPtr flatbuffer_graph = vkgraph::GetVkGraph(flatbuffer_data); - GraphBuilder builder = - GraphBuilder(compute_graph, flatbuffer_graph, constant_data); + GraphBuilder builder(compute_graph, flatbuffer_graph, constant_data); builder.build_graph();