Skip to content

[ET-VK] Using vector for storing ref_mapping_ in GraphBuilder to improve model load time and memory. #10647

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: gh/trivedivivek/81/base
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions backends/vulkan/runtime/VulkanBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ class GraphBuilder {
VkGraphPtr flatbuffer_;
const uint8_t* constant_data_;

std::unordered_map<uint32_t, ValueRef> ref_mapping_;
std::vector<ValueRef> ref_mapping_;

public:
explicit GraphBuilder(
Expand All @@ -164,22 +164,20 @@ class GraphBuilder {
constant_data_(constant_data),
ref_mapping_() {}

bool fb_id_exists(const uint32_t fb_id) {
const std::unordered_map<uint32_t, ValueRef>::iterator found_ref =
ref_mapping_.find(fb_id);
void resize(uint32_t size) {
ref_mapping_.resize(size, INT32_MAX);
}

return found_ref != ref_mapping_.end();
bool fb_id_exists(const uint32_t fb_id) {
return fb_id < ref_mapping_.size() && ref_mapping_[fb_id] != INT32_MAX;
}

ValueRef get_fb_id_valueref(const uint32_t fb_id) {
const std::unordered_map<uint32_t, ValueRef>::iterator found_ref =
ref_mapping_.find(fb_id);

ET_CHECK_MSG(
found_ref != ref_mapping_.end(),
fb_id_exists(fb_id),
"Trying to extract a value that hasn't yet been added to the graph.");

return found_ref->second;
return ref_mapping_[fb_id];
}

void add_tensor_to_graph(const uint32_t fb_id, VkTensorPtr tensor_fb) {
Expand Down Expand Up @@ -315,6 +313,9 @@ class GraphBuilder {
}

void build_graph() {
// Resize the mapping to the number of values in the flatbuffer
resize(flatbuffer_->values()->size());

// First, add all values to the graph
for (uint32_t fb_id = 0; fb_id < flatbuffer_->values()->size(); ++fb_id) {
VkValuePtr value = flatbuffer_->values()->Get(fb_id);
Expand Down Expand Up @@ -489,8 +490,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {

VkGraphPtr flatbuffer_graph = vkgraph::GetVkGraph(flatbuffer_data);

GraphBuilder builder =
GraphBuilder(compute_graph, flatbuffer_graph, constant_data);
GraphBuilder builder(compute_graph, flatbuffer_graph, constant_data);

builder.build_graph();

Expand Down
Loading