Skip to content
13 changes: 12 additions & 1 deletion backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,18 @@ void add_native_layer_norm_node(
std::vector<int64_t> in_sizes = t_input->sizes();

utils::uvec3 global_size = t_out->logical_limits();
utils::uvec3 local_size = graph.create_local_wg_size(global_size);
utils::uvec3 local_size;

// The shader sets a shared memory scale factor > 1 when the dispatch size is
// greater than the maximum WG size. In that case, setting the WG size in the
// X axis to the maximum WG size allows the best thread utilization.
if (global_size[0] > 64) {
local_size = {64, 1, 1};
} else {
// If the thread size in the X axis is smaller than or equal to the maximum
// WG size, we can let the function decide the best WG size.
local_size = graph.create_local_wg_size(global_size);
}

std::string kernel_name("native_layer_norm");
kernel_name.reserve(kShaderNameReserve);
Expand Down
Loading