Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions backends/vulkan/runtime/VulkanBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ class GraphBuilder {
VkGraphPtr flatbuffer_;
const uint8_t* constant_data_;

std::unordered_map<uint32_t, ValueRef> ref_mapping_;
std::vector<ValueRef> ref_mapping_;

public:
explicit GraphBuilder(
Expand All @@ -164,22 +164,20 @@ class GraphBuilder {
constant_data_(constant_data),
ref_mapping_() {}

bool fb_id_exists(const uint32_t fb_id) {
const std::unordered_map<uint32_t, ValueRef>::iterator found_ref =
ref_mapping_.find(fb_id);
void resize(uint32_t size) {
ref_mapping_.resize(size, INT32_MAX);
}

return found_ref != ref_mapping_.end();
bool fb_id_exists(const uint32_t fb_id) {
return fb_id < ref_mapping_.size() && ref_mapping_[fb_id] != INT32_MAX;
}

ValueRef get_fb_id_valueref(const uint32_t fb_id) {
const std::unordered_map<uint32_t, ValueRef>::iterator found_ref =
ref_mapping_.find(fb_id);

ET_CHECK_MSG(
found_ref != ref_mapping_.end(),
fb_id_exists(fb_id),
"Trying to extract a value that hasn't yet been added to the graph.");

return found_ref->second;
return ref_mapping_[fb_id];
}

void add_tensor_to_graph(const uint32_t fb_id, VkTensorPtr tensor_fb) {
Expand Down Expand Up @@ -315,6 +313,9 @@ class GraphBuilder {
}

void build_graph() {
// Resize the mapping to the number of values in the flatbuffer
resize(flatbuffer_->values()->size());

// First, add all values to the graph
for (uint32_t fb_id = 0; fb_id < flatbuffer_->values()->size(); ++fb_id) {
VkValuePtr value = flatbuffer_->values()->Get(fb_id);
Expand Down Expand Up @@ -489,8 +490,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {

VkGraphPtr flatbuffer_graph = vkgraph::GetVkGraph(flatbuffer_data);

GraphBuilder builder =
GraphBuilder(compute_graph, flatbuffer_graph, constant_data);
GraphBuilder builder(compute_graph, flatbuffer_graph, constant_data);

builder.build_graph();

Expand Down
Loading