Skip to content

Commit 5a997dc

Browse files
committed
[ET-VK] Using vector for storing ref_mapping_ in GraphBuilder to improve model load time and memory.
This diff changes GraphBuilder class to store ref id to value mapping as vector instead of unordered map, since maximum id is known and thus vector can be sized to store the map. Differential Revision: [D73969916](https://our.internmc.facebook.com/intern/diff/D73969916/) ghstack-source-id: 281563018 Pull Request resolved: #10647
1 parent cc6b9c1 commit 5a997dc

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ class GraphBuilder {
152152
VkGraphPtr flatbuffer_;
153153
const uint8_t* constant_data_;
154154

155-
std::unordered_map<uint32_t, ValueRef> ref_mapping_;
155+
std::vector<ValueRef> ref_mapping_;
156156

157157
public:
158158
explicit GraphBuilder(
@@ -164,22 +164,20 @@ class GraphBuilder {
164164
constant_data_(constant_data),
165165
ref_mapping_() {}
166166

167-
bool fb_id_exists(const uint32_t fb_id) {
168-
const std::unordered_map<uint32_t, ValueRef>::iterator found_ref =
169-
ref_mapping_.find(fb_id);
167+
void resize(uint32_t size) {
168+
ref_mapping_.resize(size, INT32_MAX);
169+
}
170170

171-
return found_ref != ref_mapping_.end();
171+
bool fb_id_exists(const uint32_t fb_id) {
172+
return fb_id < ref_mapping_.size() && ref_mapping_[fb_id] != INT32_MAX;
172173
}
173174

174175
ValueRef get_fb_id_valueref(const uint32_t fb_id) {
175-
const std::unordered_map<uint32_t, ValueRef>::iterator found_ref =
176-
ref_mapping_.find(fb_id);
177-
178176
ET_CHECK_MSG(
179-
found_ref != ref_mapping_.end(),
177+
fb_id_exists(fb_id),
180178
"Trying to extract a value that hasn't yet been added to the graph.");
181179

182-
return found_ref->second;
180+
return ref_mapping_[fb_id];
183181
}
184182

185183
void add_tensor_to_graph(const uint32_t fb_id, VkTensorPtr tensor_fb) {
@@ -315,6 +313,9 @@ class GraphBuilder {
315313
}
316314

317315
void build_graph() {
316+
// Resize the mapping to the number of values in the flatbuffer
317+
resize(flatbuffer_->values()->size());
318+
318319
// First, add all values to the graph
319320
for (uint32_t fb_id = 0; fb_id < flatbuffer_->values()->size(); ++fb_id) {
320321
VkValuePtr value = flatbuffer_->values()->Get(fb_id);
@@ -489,8 +490,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
489490

490491
VkGraphPtr flatbuffer_graph = vkgraph::GetVkGraph(flatbuffer_data);
491492

492-
GraphBuilder builder =
493-
GraphBuilder(compute_graph, flatbuffer_graph, constant_data);
493+
GraphBuilder builder(compute_graph, flatbuffer_graph, constant_data);
494494

495495
builder.build_graph();
496496

0 commit comments

Comments
 (0)