
Commit 18e0be3

[ET-VK] Simplifying conv1d op shader by changing it to process one output texel per thread.
Pull Request resolved: #10665

This diff changes the conv1d shader to process one output texel per thread, increasing GPU occupancy and improving performance.

ghstack-source-id: 282024192
Differential Revision: [D74097560](https://our.internmc.facebook.com/intern/diff/D74097560/)
1 parent: ab09362 · commit: 18e0be3
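For a rough sense of the occupancy change: the old kernel launched one invocation per output channel, while the new one launches one invocation per output texel. A minimal standalone C++ sketch of that arithmetic (not code from this commit; the shape numbers are made up for illustration):

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical output shape: batch N, channels out_C, length out_L.
  const uint32_t N = 8, out_C = 64, out_L = 128;
  // Old dispatch: global size {1, out_C, 1} -> out_C useful invocations.
  const uint32_t old_threads = out_C;
  // New dispatch: global size {out_L, out_C, ceil(N / 4)} -> one invocation
  // per output texel (batches are packed four to a texel).
  const uint32_t new_threads = out_L * out_C * ((N + 3) / 4);
  std::printf("old: %u invocations, new: %u invocations\n",
              old_threads, new_threads);
  return 0;
}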

2 files changed, +60 −65 lines changed

backends/vulkan/runtime/graph/ops/glsl/conv1d.glsl

Lines changed: 41 additions & 53 deletions
@@ -56,75 +56,63 @@ const lowp ivec4 bias_axis_map = unhash_axis_map(bias_layout);
 // weight = (out_C, in_C / G, K),
 // bias = (out_C,).
 //
-// This implementation performs out_C shader invocations, where each invocation
+// This implementation performs N x out_C x out_L shader invocations, where each invocation
 // calculates the rolling kernel of the length dimension for each batch, i.e.,
-// computes out_L * N results.
-//
-// Note that we can rewrite this implementation as out_L * out_C * ceil(N / 4)
-// shader invocations, where each invocation computes 1 result. But that
-// performs worse.
+// computes out_L results.
 void main() {
   const ivec3 lpos = ivec3(gl_GlobalInvocationID);

   if (any(greaterThanEqual(lpos, out_limits))) {
     return;
   }

-  int in_length = in_sizes.x;
-  int batch_size = in_sizes.z;
-
   // "out_c" is the output's channel index where we write our result.
   // Across shader invocations, this is the only value that varies.
-  int out_c = lpos.y;
-  VEC4_T bias = load_texel_lpos(bias_in, ivec3(out_c, 0, 0), bias_axis_map);
+  const int out_c = lpos.y;

   // "in_c" tracks the input's channel start index.
   // We iterate over the input group that corresponds to the output group.
-  int c_start = (out_c / out_group_size) * in_group_size;
-  int c_end = c_start + in_group_size;
+  const int c_start = (out_c / out_group_size) * in_group_size;
+  const int c_end = c_start + in_group_size;
+
+  // "out_l" tracks the output's length index where we write our result.
+  const int out_l = lpos.x;
+
+  // "N" is the batch index
+  const int N = lpos.z;

   // "in_l" tracks the input's length start index for our input-kernel overlay
   // region.
-  int l_start = -padding;
-  int l_end = in_length + padding - dilation * (kernel_size - 1);
-
-  // Since the input/output tensors are channel-packed, which is along the
-  // batch dimension, we can batch-read/write four elements at a time.
-  for (int n = 0; n < batch_size; n += 4) {
-    // "out_l" tracks the output's length index where we write our result.
-    int out_l = 0;
-
-    for (int in_l = l_start; in_l < l_end; in_l += stride, ++out_l) {
-      VEC4_T sum = VEC4_T(0);
-
-      for (int in_c = c_start; in_c < c_end; ++in_c) {
-        // "k" tracks the kernel's index for our input-kernel computation.
-        // It reads out-of-bound zeros, but trying to avoid them complicates
-        // for-loop conditions, which results in worse performance.
-
-        // The weight tensor is channel-packed. It may not be trival choice for
-        // performance reason since need to have more data fetch. The reason is
-        // for some sequence model, we found that the weight tensor
-        // (out_channel, in_channel / group, kernel) often has a large
-        // out_channel >> kernel, leading to non-optimal use of memory as the
-        // weight tensor gets very deep. As a mitigation, we use channel-packing
-        // for the weight tensor, yielding a 75% reduction in weight-tensor
-        // memory.
-
-        // It is possible to further reduce the memory footprint by swapping the
-        // dimensions, using x extent for out_channel, and y for kernel.
-        for (int k = 0; k < kernel_size; k += 1) {
-          const ivec3 w_lpos = ivec3(k, in_c % in_group_size, out_c / 4);
-          const VEC4_T weight_texel = load_texel_lpos(kernel_in, w_lpos, kernel_axis_map);
-          VEC4_T weight = VEC4_T(weight_texel[out_c % 4]);
-
-          ivec3 in_pos = lpos_to_pos(ivec3(in_l + k * dilation, in_c, n / 4), in_axis_map);
-          sum = fma(weight, load_texel(t_in, in_pos), sum);
-        }
-      }
-
-      const ivec3 out_lpos = ivec3(out_l, out_c, n / 4);
-      write_texel_lpos(t_out, out_lpos, op(sum + bias.x, out_min, out_max), out_axis_map);
+  const int in_l = out_l * stride - padding;
+  VEC4_T sum = VEC4_T(0);
+
+  for (int in_c = c_start; in_c < c_end; ++in_c) {
+    // "k" tracks the kernel's index for our input-kernel computation.
+    // It reads out-of-bound zeros, but trying to avoid them complicates
+    // for-loop conditions, which results in worse performance.
+
+    // The weight tensor is channel-packed. It may not be trival choice for
+    // performance reason since need to have more data fetch. The reason is
+    // for some sequence model, we found that the weight tensor
+    // (out_channel, in_channel / group, kernel) often has a large
+    // out_channel >> kernel, leading to non-optimal use of memory as the
+    // weight tensor gets very deep. As a mitigation, we use channel-packing
+    // for the weight tensor, yielding a 75% reduction in weight-tensor
+    // memory.
+
+    // It is possible to further reduce the memory footprint by swapping the
+    // dimensions, using x extent for out_channel, and y for kernel.
+    for (int k = 0; k < kernel_size; k++) {
+      const ivec3 w_lpos = ivec3(k, in_c % in_group_size, out_c / 4);
+      const VEC4_T weight_texel = load_texel_lpos(kernel_in, w_lpos, kernel_axis_map);
+      VEC4_T weight = VEC4_T(weight_texel[out_c % 4]);
+
+      const ivec3 in_pos = lpos_to_pos(ivec3(in_l + k * dilation, in_c, N), in_axis_map);
+      sum = fma(weight, load_texel(t_in, in_pos), sum);
     }
   }
+
+  const VEC4_T bias = load_texel_lpos(bias_in, ivec3(out_c, 0, 0), bias_axis_map);
+  const ivec3 out_lpos = ivec3(out_l, out_c, N);
+  write_texel_lpos(t_out, out_lpos, op(sum + bias.x, out_min, out_max), out_axis_map);
 }
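To make the new per-invocation work concrete, here is a minimal scalar C++ reference of what one shader invocation now computes for its (batch, out_c, out_l) coordinate. This is an illustrative sketch, not code from this commit: it ignores the texel packing and axis-map indirection, assumes dense row-major in/weight/bias arrays, and omits the optional min/max clamp that the shader applies via op():

// Input is (N, in_C, in_L), weight is (out_C, in_C / G, K), bias is (out_C,).
float conv1d_at(const float* in, const float* weight, const float* bias,
                int in_C, int in_L, int in_group_size, int out_group_size,
                int kernel_size, int stride, int padding, int dilation,
                int n, int out_c, int out_l) {
  // First input channel of the group that feeds this output channel.
  const int c_start = (out_c / out_group_size) * in_group_size;
  // Start of the kernel window, mirroring in_l = out_l * stride - padding.
  const int in_l = out_l * stride - padding;
  float sum = 0.f;
  for (int ic = 0; ic < in_group_size; ++ic) {
    for (int k = 0; k < kernel_size; ++k) {
      const int l = in_l + k * dilation;
      if (l < 0 || l >= in_L) {
        continue;  // the shader instead reads zeros out of bounds
      }
      sum += weight[(out_c * in_group_size + ic) * kernel_size + k] *
             in[(n * in_C + c_start + ic) * in_L + l];
    }
  }
  return sum + bias[out_c];
}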

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 19 additions & 12 deletions
@@ -505,17 +505,24 @@ void add_conv1d_node(

   check_conv_args(*t_in, *t_out);

-  int32_t in_channels = in_sizes.at(1);
-  int32_t out_channels = weight_sizes.at(0);
-  int32_t kernel_size = weight_sizes.at(2);
-  int32_t stride_size = graph.get_int_list(stride)->at(0);
-  int32_t padding_size = graph.get_int_list(padding)->at(0);
-  int32_t dilation_size = graph.get_int_list(dilation)->at(0);
-  int32_t in_group_size = static_cast<int64_t>(in_channels / groups_val);
-  int32_t out_group_size = static_cast<int64_t>(out_channels / groups_val);
-
-  utils::uvec3 global_size = {1, static_cast<uint32_t>(out_channels), 1};
-  utils::uvec3 local_size = {1, 64, 1};
+  const int32_t in_channels = in_sizes.at(1);
+  const int32_t out_channels = weight_sizes.at(0);
+  const int32_t kernel_size = weight_sizes.at(2);
+  const int32_t stride_size = graph.get_int_list(stride)->at(0);
+  const int32_t padding_size = graph.get_int_list(padding)->at(0);
+  const int32_t dilation_size = graph.get_int_list(dilation)->at(0);
+  const int32_t in_group_size = static_cast<int64_t>(in_channels / groups_val);
+  const int32_t out_group_size =
+      static_cast<int64_t>(out_channels / groups_val);
+
+  const utils::uvec3 global_size = {
+      // out length
+      graph.size_at<uint32_t>(-1, out),
+      // out channels
+      static_cast<uint32_t>(out_channels),
+      // out batches
+      utils::div_up_4(graph.size_at<uint32_t>(-3, out))};
+  const utils::uvec3 local_size = graph.create_local_wg_size(global_size);

   Kernel1dParams kernel_params = {
       kernel_size,
@@ -525,7 +532,7 @@ void add_conv1d_node(
       in_group_size,
       out_group_size};

-  OutputParams out_params = {out_min_val, out_max_val};
+  const OutputParams out_params = {out_min_val, out_max_val};

   std::string kernel_name("conv1d");
   if (clamp_out) {
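The z extent of the new global size goes through utils::div_up_4 because batch elements are packed four to a texel, so the grid needs one invocation per batch texel rather than per batch element. A small standalone C++ sketch of that mapping (not code from this commit; div_up_4 is re-implemented here on the assumption that it is a ceiling division by 4, which its use above suggests, and the batch count is made up):

#include <cstdint>
#include <cstdio>

// Assumed behavior of utils::div_up_4: ceiling division by 4.
constexpr uint32_t div_up_4(uint32_t n) { return (n + 3) / 4; }

int main() {
  const uint32_t N = 6;  // hypothetical batch count
  std::printf("z extent = %u batch texels\n", div_up_4(N));  // -> 2
  for (uint32_t n = 0; n < N; ++n) {
    // Each batch element lands in texel z = n / 4, component n % 4.
    std::printf("batch %u -> texel z %u, component %u\n", n, n / 4, n % 4);
  }
  return 0;
}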
