diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl index 552037247fd..e44a41fc9bc 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl @@ -49,12 +49,12 @@ void main() { const uint div_by_x = gl_GlobalInvocationID.x / out_limits_scaled.x; const ivec3 gpos = ivec3( gl_GlobalInvocationID.x % out_limits_scaled.x, - div_by_x % out_limits_scaled.y, - div_by_x / out_limits_scaled.y); + div_by_x, + gl_GlobalInvocationID.y); // If the top left position is out of bounds, then this invocation will have // no work to do. - if (gpos.z >= out_limits.z) { + if (gpos.y >= out_limits_scaled.y || gpos.z >= out_limits.z) { return; } diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp index a0ac58ea9bc..5250c3baef2 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp @@ -398,8 +398,10 @@ void add_conv2d_node( utils::uvec3 wg_size = create_conv2d_global_wg_size( graph, method, out, weight_data, stride_equals_dilation); - if (method == Conv2dMethod::Pointwise || method == Conv2dMethod::Depthwise) { + if (method == Conv2dMethod::Depthwise) { wg_size = {wg_size[0] * wg_size[1] * wg_size[2], 1, 1}; + } else if (method == Conv2dMethod::Pointwise) { + wg_size = {wg_size[0] * wg_size[1], wg_size[2], 1}; } vkapi::ParamsBindList param_buffers;