Skip to content
6 changes: 3 additions & 3 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,12 @@ void main() {
const uint div_by_x = gl_GlobalInvocationID.x / out_limits_scaled.x;
const ivec3 gpos = ivec3(
gl_GlobalInvocationID.x % out_limits_scaled.x,
div_by_x % out_limits_scaled.y,
div_by_x / out_limits_scaled.y);
div_by_x,
gl_GlobalInvocationID.y);

// If the top left position is out of bounds, then this invocation will have
// no work to do.
if (gpos.z >= out_limits.z) {
if (gpos.y >= out_limits_scaled.y || gpos.z >= out_limits.z) {
return;
}

Expand Down
4 changes: 3 additions & 1 deletion backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -398,8 +398,10 @@ void add_conv2d_node(
utils::uvec3 wg_size = create_conv2d_global_wg_size(
graph, method, out, weight_data, stride_equals_dilation);

if (method == Conv2dMethod::Pointwise || method == Conv2dMethod::Depthwise) {
if (method == Conv2dMethod::Depthwise) {
wg_size = {wg_size[0] * wg_size[1] * wg_size[2], 1, 1};
} else if (method == Conv2dMethod::Pointwise) {
wg_size = {wg_size[0] * wg_size[1], wg_size[2], 1};
}

vkapi::ParamsBindList param_buffers;
Expand Down
Loading