@@ -56,75 +56,63 @@ const lowp ivec4 bias_axis_map = unhash_axis_map(bias_layout);
 // weight = (out_C, in_C / G, K),
 // bias = (out_C,).
 //
-// This implementation performs out_C shader invocations, where each invocation
+// This implementation performs N x out_C x out_L shader invocations, where each invocation
 // calculates the rolling kernel of the length dimension for each batch, i.e.,
-// computes out_L * N results.
-//
-// Note that we can rewrite this implementation as out_L * out_C * ceil(N / 4)
-// shader invocations, where each invocation computes 1 result. But that
-// performs worse.
+// computes 1 result.
 void main() {
   const ivec3 lpos = ivec3(gl_GlobalInvocationID);

   if (any(greaterThanEqual(lpos, out_limits))) {
     return;
   }

-  int in_length = in_sizes.x;
-  int batch_size = in_sizes.z;
-
   // "out_c" is the output's channel index where we write our result.
   // Across shader invocations, this is the only value that varies.
-  int out_c = lpos.y;
-  VEC4_T bias = load_texel_lpos(bias_in, ivec3(out_c, 0, 0), bias_axis_map);
+  const int out_c = lpos.y;

   // "in_c" tracks the input's channel start index.
   // We iterate over the input group that corresponds to the output group.
-  int c_start = (out_c / out_group_size) * in_group_size;
-  int c_end = c_start + in_group_size;
+  const int c_start = (out_c / out_group_size) * in_group_size;
+  const int c_end = c_start + in_group_size;
+
+  // "out_l" tracks the output's length index where we write our result.
+  const int out_l = lpos.x;
+
+  // "N" is the batch index
+  const int N = lpos.z;

   // "in_l" tracks the input's length start index for our input-kernel overlay
   // region.
-  int l_start = -padding;
-  int l_end = in_length + padding - dilation * (kernel_size - 1);
-
-  // Since the input/output tensors are channel-packed, which is along the
-  // batch dimension, we can batch-read/write four elements at a time.
-  for (int n = 0; n < batch_size; n += 4) {
-    // "out_l" tracks the output's length index where we write our result.
-    int out_l = 0;
-
-    for (int in_l = l_start; in_l < l_end; in_l += stride, ++out_l) {
-      VEC4_T sum = VEC4_T(0);
-
-      for (int in_c = c_start; in_c < c_end; ++in_c) {
-        // "k" tracks the kernel's index for our input-kernel computation.
-        // It reads out-of-bound zeros, but trying to avoid them complicates
-        // for-loop conditions, which results in worse performance.
-
-        // The weight tensor is channel-packed. It may not be trival choice for
-        // performance reason since need to have more data fetch. The reason is
-        // for some sequence model, we found that the weight tensor
-        // (out_channel, in_channel / group, kernel) often has a large
-        // out_channel >> kernel, leading to non-optimal use of memory as the
-        // weight tensor gets very deep. As a mitigation, we use channel-packing
-        // for the weight tensor, yielding a 75% reduction in weight-tensor
-        // memory.
-
-        // It is possible to further reduce the memory footprint by swapping the
-        // dimensions, using x extent for out_channel, and y for kernel.
-        for (int k = 0; k < kernel_size; k += 1) {
-          const ivec3 w_lpos = ivec3(k, in_c % in_group_size, out_c / 4);
-          const VEC4_T weight_texel = load_texel_lpos(kernel_in, w_lpos, kernel_axis_map);
-          VEC4_T weight = VEC4_T(weight_texel[out_c % 4]);
-
-          ivec3 in_pos = lpos_to_pos(ivec3(in_l + k * dilation, in_c, n / 4), in_axis_map);
-          sum = fma(weight, load_texel(t_in, in_pos), sum);
-        }
-      }
-
-      const ivec3 out_lpos = ivec3(out_l, out_c, n / 4);
-      write_texel_lpos(t_out, out_lpos, op(sum + bias.x, out_min, out_max), out_axis_map);
+  const int in_l = out_l * stride - padding;
+  VEC4_T sum = VEC4_T(0);
+
+  for (int in_c = c_start; in_c < c_end; ++in_c) {
+    // "k" tracks the kernel's index for our input-kernel computation.
+    // It reads out-of-bound zeros, but trying to avoid them complicates
+    // for-loop conditions, which results in worse performance.
+
+    // The weight tensor is channel-packed. This may not be the obvious choice
+    // for performance, since it requires more data fetches. However, for some
+    // sequence models we found that the weight tensor
+    // (out_channel, in_channel / group, kernel) often has
+    // out_channel >> kernel, leading to non-optimal use of memory as the
+    // weight texture gets very deep. As a mitigation, we use channel-packing
+    // for the weight tensor, yielding a 75% reduction in weight-tensor
+    // memory.
+
+    // It is possible to further reduce the memory footprint by swapping the
+    // dimensions, using x extent for out_channel, and y for kernel.
+    for (int k = 0; k < kernel_size; k++) {
+      const ivec3 w_lpos = ivec3(k, in_c % in_group_size, out_c / 4);
+      const VEC4_T weight_texel = load_texel_lpos(kernel_in, w_lpos, kernel_axis_map);
+      VEC4_T weight = VEC4_T(weight_texel[out_c % 4]);
+
+      const ivec3 in_pos = lpos_to_pos(ivec3(in_l + k * dilation, in_c, N), in_axis_map);
+      sum = fma(weight, load_texel(t_in, in_pos), sum);
     }
   }
+
+  const VEC4_T bias = load_texel_lpos(bias_in, ivec3(out_c, 0, 0), bias_axis_map);
+  const ivec3 out_lpos = ivec3(out_l, out_c, N);
+  write_texel_lpos(t_out, out_lpos, op(sum + bias.x, out_min, out_max), out_axis_map);
 }
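
For readers following the indexing, here is a minimal scalar sketch, in C++ outside the shader, of what one invocation of the new kernel computes: the single output element at (batch n, output channel oc, output position ol). It assumes plain row-major CPU layouts — input (N, in_C, in_L), weight (out_C, in_C / G, K), bias (out_C,) — and the function name and signature are illustrative only, not part of the ExecuTorch API; the texture packing and axis maps used by the shader are intentionally left out.

#include <vector>

// Hypothetical scalar reference for ONE shader invocation: the output element
// at batch n, output channel oc, output position ol. Layouts are assumed
// row-major: input (N, in_C, in_L), weight (out_C, in_C / groups, K), bias (out_C,).
float conv1d_reference_at(
    const std::vector<float>& input,
    const std::vector<float>& weight,
    const std::vector<float>& bias,
    int in_C, int in_L, int K,
    int stride, int padding, int dilation, int groups, int out_C,
    int n, int oc, int ol) {
  const int in_group_size = in_C / groups;
  const int out_group_size = out_C / groups;
  // Same group mapping as the shader: c_start = (out_c / out_group_size) * in_group_size.
  const int c_start = (oc / out_group_size) * in_group_size;
  // Same start index as the shader: in_l = out_l * stride - padding.
  const int in_l = ol * stride - padding;

  float sum = 0.0f;
  for (int ic = c_start; ic < c_start + in_group_size; ++ic) {
    for (int k = 0; k < K; ++k) {
      const int l = in_l + k * dilation;
      // The shader reads out-of-bound zeros from the texture; on the CPU we skip instead.
      if (l < 0 || l >= in_L) continue;
      // (ic - c_start) equals the shader's in_c % in_group_size, since c_start
      // is always a multiple of in_group_size.
      const float w = weight[(oc * in_group_size + (ic - c_start)) * K + k];
      const float x = input[(n * in_C + ic) * in_L + l];
      sum += w * x;
    }
  }
  return sum + bias[oc];
}

Under those assumptions the output length is out_L = (in_L + 2 * padding - dilation * (K - 1) - 1) / stride + 1, which matches the "N x out_C x out_L shader invocations" described in the header comment: lpos.x, lpos.y, and lpos.z select out_l, out_c, and the batch index, respectively.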