Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ void main() {
int kx = 0;
for (int y = start.y; y < end.y; y += params.dilation.y) {
for (int x = start.x; x < end.x; x += params.dilation.x) {
// The weight kernel was rearranged so that every NxN filter is flattened
// to fits in one row. Each filter was then stacked on top of each other
// vertically.
// The weight kernel was rearranged such that every NxN filter is
// flattened to fit in one row. Each filter was then stacked on top of
// each other vertically.
const ${VEC4_T[DTYPE]} in_texel = texelFetch(image_in, ivec3(x, y, pos.z), 0);
sum = fma(in_texel, texelFetch(kernel_in, ivec2(kx, pos.z), 0), sum);
++kx;
Expand Down
83 changes: 83 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}

#include "indexing_utils.h"

layout(std430) buffer;

layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
layout(set = 0, binding = 2) uniform PRECISION sampler2D kernel_in;
layout(set = 0, binding = 3) uniform PRECISION sampler2D bias_in;

layout(set = 0, binding = 4) uniform PRECISION restrict OutExtents {
uvec4 data;
}
out_extents;

layout(set = 0, binding = 5) uniform PRECISION restrict InExtents {
uvec4 data;
}
in_extents;

layout(set = 0, binding = 6) uniform PRECISION restrict Params {
ivec2 kernel_size;
ivec2 stride;
ivec2 padding;
ivec2 dilation;
}
params;

// If fields are separated, SwiftShader cannot identify in_group_size.
layout(set = 0, binding = 7) uniform PRECISION restrict ExtraParams {
ivec2 overlay_region;
int in_group_size;
}
extra_params;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
* Computes a depthwise convolution. Each shader invocation calculates the
* output at a single output location.
*/
void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);

if (any(greaterThanEqual(pos, out_extents.data.xyz))) {
return;
}

// Compute the index of the top-left element of the overlay region. Negative
// indices indicate that the top-left element is in a region added by padding.
const ivec2 ipos = pos.xy * params.stride - params.padding;

// Compute the start and end of the input indices to load. Padding is assumed
// to be constant 0 padding, so any reads from the padding region is skipped.
const ivec2 start = ipos;
const ivec2 end = ipos + extra_params.overlay_region.xy;

${VEC4_T[DTYPE]} sum = texelFetch(bias_in, ivec2(pos.z, 0), 0);
int kx = 0;
for (int y = start.y, i = 0; i < ${TILE_SIZE}; y += params.dilation.y, i++) {
for (int x = start.x, j = 0; j < ${TILE_SIZE}; x += params.dilation.x, j++) {
// The weight kernel was rearranged such that every NxN filter is
// flattened to fit in one row. Each filter was then stacked on top of
// each other vertically.
const vec4 in_texel = texelFetch(image_in, ivec3(x, y, pos.z), 0);
sum = fma(in_texel, texelFetch(kernel_in, ivec2(kx, pos.z), 0), sum);
kx++;
}
}

imageStore(image_out, pos, sum);
}
21 changes: 21 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

conv2d_dw_output_tile:
parameter_names_with_default_values:
NDIM: 3
DTYPE: float
TILE_SIZE: 3
generate_variant_forall:
DTYPE:
- VALUE: half
SUFFIX: half
- VALUE: float
SUFFIX: float
shader_variants:
- NAME: conv2d_dw_output_tile_3x3
- NAME: conv2d_dw_output_tile_5x5
TILE_SIZE: 5
33 changes: 22 additions & 11 deletions backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ void resize_conv2d_node(
if (ndim == 4) {
new_out_sizes.at(ndim - 4) = self.sizes().at(ndim - 4);
}
const auto weight_sizes = graph->get_val(extra_args[0]).toTensorRef().sizes;
const auto& weight_sizes = graph->get_val(extra_args[0]).toTensorRef().sizes;
new_out_sizes.at(ndim - 3) =
transposed ? weight_sizes.at(ndim - 3) : weight_sizes.at(ndim - 4);

// Height, Width
const auto new_out_sizes_hw = calc_out_sizes_hw(
const auto& new_out_sizes_hw = calc_out_sizes_hw(
*graph,
self.sizes(),
extra_args[0],
Expand Down Expand Up @@ -87,13 +87,24 @@ enum class Conv2dMethod : uint8_t {
};

api::ShaderInfo get_conv2d_shader(
ComputeGraph& graph,
const vTensor& t_out,
const bool prepack_weights,
const Conv2dMethod method) {
const Conv2dMethod method,
const ValueRef weight) {
std::stringstream kernel_name;
switch (method) {
case Conv2dMethod::Depthwise:
kernel_name << "conv2d_dw";
if (!prepack_weights) {
const auto& weight_sizes = graph.get_val(weight).toTensorRef().sizes;
if (weight_sizes.at(2) == 3 && weight_sizes.at(3) == 3) {
kernel_name << "_output_tile_3x3";
}
if (weight_sizes.at(2) == 5 && weight_sizes.at(3) == 5) {
kernel_name << "_output_tile_5x5";
}
}
break;
case Conv2dMethod::SlidingWindow:
kernel_name << "conv2d";
Expand Down Expand Up @@ -156,7 +167,7 @@ ValueRef prepack_weights(
const ValueRef vref,
const Conv2dMethod method) {
const auto original_sizes = graph.get_val(vref).toTensorRef().sizes;
const auto final_sizes = get_final_sizes(original_sizes, method);
const auto& final_sizes = get_final_sizes(original_sizes, method);

ValueRef v = graph.add_tensor(
final_sizes,
Expand All @@ -169,9 +180,9 @@ ValueRef prepack_weights(
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

api::ShaderInfo shader =
get_conv2d_shader(t, /*prepack_weights = */ true, method);
get_conv2d_shader(graph, t, /*prepack_weights = */ true, method, vref);

const auto padded_sizes = get_padded_sizes(original_sizes, method);
const auto& padded_sizes = get_padded_sizes(original_sizes, method);

graph.prepack_nodes().emplace_back(new PrepackNode(
graph,
Expand Down Expand Up @@ -210,13 +221,13 @@ Conv2dParams create_conv2d_params(
const ValueRef weight,
const KernelParams& p,
const bool transposed) {
const auto overlay_region = api::utils::make_ivec2({
const auto& overlay_region = api::utils::make_ivec2({
p.kernel_size.data[0] +
(p.kernel_size.data[0] - 1) * (p.dilation.data[0] - 1),
p.kernel_size.data[1] +
(p.kernel_size.data[1] - 1) * (p.dilation.data[1] - 1),
});
const auto weight_sizes = graph.get_val(weight).toTensorRef().sizes;
const auto& weight_sizes = graph.get_val(weight).toTensorRef().sizes;
const int32_t in_group_size =
api::utils::safe_downcast<int32_t>(api::utils::align_up(
transposed ? weight_sizes.at(0) : weight_sizes.at(1), INT64_C(4)));
Expand Down Expand Up @@ -244,7 +255,7 @@ Conv2dMethod get_conv2d_method(
const ValueRef weight,
const int64_t groups,
const bool transposed) {
const auto weight_sizes = graph.get_val(weight).toTensorRef().sizes;
const auto& weight_sizes = graph.get_val(weight).toTensorRef().sizes;
if (!transposed && weight_sizes.at(0) == groups && weight_sizes.at(1) == 1) {
return Conv2dMethod::Depthwise;
}
Expand Down Expand Up @@ -298,8 +309,8 @@ void add_conv2d_node(

check_conv2d_params(kernel_params, transposed_val);

api::ShaderInfo shader =
get_conv2d_shader(t_out, /*prepack_weights = */ false, method);
api::ShaderInfo shader = get_conv2d_shader(
graph, t_out, /*prepack_weights = */ false, method, weight);

graph.execute_nodes().emplace_back(new ExecuteNode(
graph,
Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/ops/impl/Pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void resize_max_pool2d_node(
new_out_sizes.at(ndim - 3) = self.sizes().at(ndim - 3);

// Height, Width
const auto new_out_sizes_hw = calc_out_sizes_hw(
const auto& new_out_sizes_hw = calc_out_sizes_hw(
*graph,
self.sizes(),
extra_args[0],
Expand Down