@@ -43,106 +43,275 @@ ${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
 const lowp ivec4 out_axis_map = unhash_axis_map(out_layout);
 const lowp int out_packed_dim = unhash_packed_dim(out_layout);
 
-void main() {
-  const ivec3 lpos = ivec3(gl_GlobalInvocationID);
+#define MAX_WORKGROUP_SIZE 64
+
+// SHARED_MEMORY_FACTOR scales the shared memory allocation and should be 1 or a power of 2.
+//
+// Increasing the factor lets more data be staged in shared memory and improves thread utilization during the reduction.
+// Why? Because each reduction iteration halves the number of active threads.
+// A larger scaling factor therefore raises thread occupancy and makes better use of the GPU.
+// e.g.
+// If the local workgroup size in the x dimension is 32 and SHARED_MEMORY_FACTOR is 1, 32 elements are loaded into shared memory.
+// The first reduce iteration has 16 threads sum up 32 elements.
+// The second iteration has 8 threads sum up the 16 partial sums from the previous iteration, and so on.
+// So thread utilization starts at 50%.
+//
+// By contrast, if the local workgroup size in the x dimension is 32 and SHARED_MEMORY_FACTOR is 2, 64 elements are loaded into shared memory.
+// The first reduce iteration has 32 threads sum up 64 elements.
+// The second iteration has 16 threads sum up the 32 partial sums from the previous iteration, and so on.
+// Thus thread utilization starts at 100%.
+#define SHARED_MEMORY_FACTOR 2
+
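+// offset_pos_index pads the shared memory index with one extra slot for every
+// 4 elements (index + index / 4), presumably to mitigate shared memory bank
+// conflicts during the strided reduction.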
+#define offset_pos_index(index) ((index) + ((index) >> 2))
+
+shared VEC4_T shared_input[offset_pos_index(MAX_WORKGROUP_SIZE * SHARED_MEMORY_FACTOR)];
+
+// Function to reduce the input data along the workgroup's x dimension
+//
+// The implementation performs a tree reduction as depicted below
+// | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | 2 | 3 | 2 | 7 | 0 | 11 | 0 | 2 | current_stride -> 1
+// | / | / | / | / | / | / | / | /
+// | / | / | / | / | / | / | / | /
+// | / | / | / | / | / | / | / | /
+// | 11 | 1 | 9 | 1 | 2 | 2 | 8 | 5 | 5 | 3 | 9 | 7 | 11 | 11 | 2 | 2 | current_stride -> 2
+// | / | / | / | /
+// | / | / | / | /
+// | / | / | / | /
+// | 20 | 1 | 9 | 1 | 10 | 2 | 8 | 5 | 14 | 3 | 9 | 7 | 13 | 11 | 2 | 2 | current_stride -> 4
+// | / | /
+// | / | /
+// | / | /
+// | / | /
+// | / | /
+// | 30 | 1 | 9 | 1 | 10 | 2 | 8 | 5 | 27 | 3 | 9 | 7 | 13 | 11 | 2 | 2 | current_stride -> 8
+// | /
+// | /
+// | /
+// | /
+// | /
+// | /
+// | /
+// | /
+// | /
+// | 57 | 1 | 9 | 1 | 10 | 2 | 8 | 5 | 27 | 3 | 9 | 7 | 13 | 11 | 2 | 2 | current_stride -> 16
+//
+// Threads access the shared indices in the following pattern
+// Thread       | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | current_stride -> 1
+// Shared Index | 0 | 2 | 4 | 6 | 8 | 10 | 12 | 14 | X | X | X | X | X | X | X | X | index *= 1
+//
+// Thread       | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | current_stride -> 2
+// Shared Index | 0 | 4 | 8 | 12 | X | X | X | X | X | X | X | X | X | X | X | X | index *= 2
+//
+// Thread       | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | current_stride -> 4
+// Shared Index | 0 | 8 | X | X | X | X | X | X | X | X | X | X | X | X | X | X | index *= 4
+//
+// Thread       | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | current_stride -> 8
+// Shared Index | 0 | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | index *= 8
+
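+// In code terms: in the iteration where current_stride == (1 << k), thread t
+// operates on index = (t << 1) << k, i.e. 2 * t * current_stride, relative to
+// shared_idx_offset.
+//
+// Note: the reduction below assumes width_stride is a power of two and that every
+// thread in the workgroup reaches the barriers, which is why the callers store
+// neutral values for out-of-bounds positions instead of returning early.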
+void reduce_input(const int width_stride, const int shared_idx_offset) {
+  // Wait for all shared memory writes to finish
+  memoryBarrierShared();
+  barrier();
+
+  // Loop log2(width_stride) times
+  for (int current_stride = 1, index = int(gl_LocalInvocationID.x << 1); current_stride < width_stride; current_stride *= 2, index <<= 1) {
+    // If the index at this thread is within the width stride
+    if (index < width_stride) {
+      const int local_shared_idx = shared_idx_offset + index;
+      // Add the value at the current stride to this thread's value
+      shared_input[offset_pos_index(local_shared_idx)] += shared_input[offset_pos_index(local_shared_idx + current_stride)];
+    }
 
-  if (any(greaterThanEqual(lpos, out_limits))) {
-    return;
+    memoryBarrierShared();
+    barrier();
   }
+}
 
+void reduce_non_packed_dim() {
+  const ivec3 lpos = ivec3(gl_GlobalInvocationID);
   const int width = int(sizes.x);
+  ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
 
-  if (in_packed_dim != W_DIM) {
-    VEC4_T mean = VEC4_T(0);
-    VEC4_T delta = VEC4_T(0);
-    VEC4_T delta2 = VEC4_T(0);
-    VEC4_T M2 = VEC4_T(0);
-
-    // Use Welford's online algorithm to compute mean and variance in one pass
-    // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
-    ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
-    for (int w = 0; w < width; ++w) {
-      in_pos[in_axis_map.x] = w;
-      VEC4_T v = load_texel(t_in, in_pos);
-      delta = v - mean;
-      mean += delta / (w + 1);
-      delta2 = v - mean;
-      M2 += delta * delta2;
+  // Width read stride per batch
+  const int width_stride = int(gl_WorkGroupSize.x) * SHARED_MEMORY_FACTOR;
+
+  // Starting offset in shared memory for this thread's (y, z) row; each row gets its own width_stride-sized slice
+  const int shared_idx_offset = width_stride * int(gl_WorkGroupSize.y * gl_LocalInvocationID.z + gl_LocalInvocationID.y);
+
+  // Shared memory index for this thread
+  const int shared_idx = shared_idx_offset + int(gl_LocalInvocationID.x);
+
+  VEC4_T mean = VEC4_T(0);
+  VEC4_T var = VEC4_T(0);
+
+  // Loop over the width in stride increments
+  for (int width_offset = 0; width_offset < width; width_offset += width_stride) {
+    // Read the input into shared memory
+    for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
+      in_pos[in_axis_map.x] = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
+
+      VEC4_T in_val = VEC4_T(0);
+      if (all(lessThan(in_pos, out_limits))) {
+        in_val = load_texel(t_in, in_pos);
+      }
+      shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
     }
 
-    VEC4_T var = M2 / width;
-    VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5));
-    VEC4_T offset = -rstd * mean;
-
-    for (int w = 0; w < width; ++w) {
-      in_pos[in_axis_map.x] = w;
-      VEC4_T v = load_texel(t_in, in_pos);
-      // broadcasting
-      VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0)).xxxx;
-      VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0)).xxxx;
-      VEC4_T outtex = (v * rstd + offset) * weight + bias;
-      write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map);
+    reduce_input(width_stride, shared_idx_offset);
+    mean += shared_input[offset_pos_index(shared_idx_offset)];
+  }
+
+  mean /= width;
+
+  memoryBarrierShared();
+  barrier();
+
+  // Loop over the width in stride increments
+  for (int width_offset = 0; width_offset < width; width_offset += width_stride) {
+    // Read the input into shared memory
+    for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
+      in_pos[in_axis_map.x] = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
+
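+      // Out-of-bounds positions are filled with the mean so that their delta below
+      // is zero and they do not contribute to the variance sum.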
+      VEC4_T in_val = mean;
+      if (all(lessThan(in_pos, out_limits))) {
+        in_val = load_texel(t_in, in_pos);
+      }
+
+      const VEC4_T delta = in_val - mean;
+      shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta * delta;
     }
 
+    reduce_input(width_stride, shared_idx_offset);
+    var += shared_input[offset_pos_index(shared_idx_offset)];
+  }
+
+  var /= width;
+
+  VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5));
+  VEC4_T offset = -rstd * mean;
+
+  VEC4_T v = load_texel(t_in, lpos);
+  VEC4_T weight = load_texel(t_weight, ivec3(lpos.x, 0, 0)).xxxx;
+  VEC4_T bias = load_texel(t_bias, ivec3(lpos.x, 0, 0)).xxxx;
+  VEC4_T outtex = (v * rstd + offset) * weight + bias;
+
+  if (all(lessThan(lpos, out_limits))) {
+    write_texel_lpos(t_out, lpos, outtex, out_axis_map);
+  }
+
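+  // Only the first invocation along x writes the per-row mean and rstd.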
+  if (gl_GlobalInvocationID.x == 0) {
     write_texel(t_mean, lpos, mean);
     write_texel(t_rstd, lpos, rstd);
-  } else {
-    const int packed_width = divup4(width);
-
-    T mean = T(0);
-    T delta = T(0);
-    T delta2 = T(0);
-    T M2 = T(0);
-    // Use Welford's online algorithm to compute mean and variance in one pass
-    // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
-    ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
-    T width_counter = T(1);
-
-    const bool has_unaligned_width = (width & 0x3) != 0;
-    const int fully_packed_4_comp_count = packed_width - mix(0, 1, has_unaligned_width);
-
-    // iterate through texels that are fully packed ie. has 4 components
-    for (int w = 0; w < fully_packed_4_comp_count; ++w) {
-      in_pos[in_axis_map.x] = w;
-      VEC4_T v = load_texel(t_in, in_pos);
-      for (int i = 0; i < 4; i++) {
-        delta = v[i] - mean;
-        mean += delta / width_counter;
-        delta2 = v[i] - mean;
-        M2 += delta * delta2;
-        width_counter++;
+  }
+}
+
+void reduce_packed_dim() {
+  const ivec3 lpos = ivec3(gl_GlobalInvocationID);
+  const int width = int(sizes.x);
+  ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
+
+  // Width read stride per batch
+  const int width_stride = int(gl_WorkGroupSize.x) * SHARED_MEMORY_FACTOR;
+
+  // Starting offset in shared memory for this thread's (y, z) row
+  const int shared_idx_offset = width_stride * int(gl_WorkGroupSize.y * gl_LocalInvocationID.z + gl_LocalInvocationID.y);
+
+  // Shared memory index for this thread
+  const int shared_idx = shared_idx_offset + int(gl_LocalInvocationID.x);
+
+  const int last_packed_width_index = divup4(width) - 1;
+  T mean = T(0);
+  T var = T(0);
+  const int remain = width & 3;
+
+  const int in_pos_x_limit = out_limits[in_axis_map.x];
+
+  // Loop over the packed width in stride increments
+  for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
+    // Read the input into shared memory
+    for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
+      const int in_pos_x = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
+      in_pos[in_axis_map.x] = in_pos_x;
+
+      VEC4_T in_val = VEC4_T(0);
+      if (in_pos_x < in_pos_x_limit) {
+        in_val = load_texel(t_in, in_pos);
       }
-    }
 
-    // handle last texel if its not 4 aligned
-    if (has_unaligned_width) {
-      in_pos[in_axis_map.x] = fully_packed_4_comp_count;
-      const int remaining_width = width & 0x3;
-
-      VEC4_T v = load_texel(t_in, in_pos);
-      for (int i = 0; i < remaining_width; i++) {
-        delta = v[i] - mean;
-        mean += delta / width_counter;
-        delta2 = v[i] - mean;
-        M2 += delta * delta2;
-        width_counter++;
+      if (in_pos_x == last_packed_width_index && remain != 0) {
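+        // Zero out the unused components of the partially filled last texel so
+        // they do not contribute to the sum.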
+        const int remain_inv = 4 - remain;
+        in_val.y = mix(in_val.y, T(0), remain_inv > 2);
+        in_val.z = mix(in_val.z, T(0), remain_inv > 1);
+        in_val.w = mix(in_val.w, T(0), remain_inv > 0);
       }
+
+      shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
     }
 
-    T var = M2 / (width_counter - 1);
-    T rstd = inversesqrt(var + epsilon);
-    T offset = -rstd * mean;
-
-    for (int w = 0; w < packed_width; ++w) {
-      in_pos[in_axis_map.x] = w;
-      VEC4_T v = load_texel(t_in, in_pos);
-      VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0));
-      VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0));
-      VEC4_T outtex = (v * rstd + offset) * weight + bias;
-      write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map);
+    reduce_input(width_stride, shared_idx_offset);
+    const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
+    mean += val.x + val.y + val.z + val.w;
+  }
+
+  mean /= width;
+
+  memoryBarrierShared();
+  barrier();
+
+  // Loop over the packed width in stride increments
+  for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
+    // Read the input into shared memory
+    for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
+      const int in_pos_x = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
+      in_pos[in_axis_map.x] = in_pos_x;
+
+      VEC4_T in_val = VEC4_T(mean);
+      if (in_pos_x < in_pos_x_limit) {
+        in_val = load_texel(t_in, in_pos);
+      }
+
+      if (in_pos_x == last_packed_width_index && remain != 0) {
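+        // For the partially filled last texel, fill the unused components with the
+        // mean so their delta is zero and they do not affect the variance sum.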
+        const int remain_inv = 4 - remain;
+        in_val.y = mix(in_val.y, mean, remain_inv > 2);
+        in_val.z = mix(in_val.z, mean, remain_inv > 1);
+        in_val.w = mix(in_val.w, mean, remain_inv > 0);
+      }
+
+      const VEC4_T delta = in_val - mean;
+      const VEC4_T delta2 = delta * delta;
+      shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta2;
     }
 
+    reduce_input(width_stride, shared_idx_offset);
+    const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
+    var += val.x + val.y + val.z + val.w;
+  }
+
+  var /= width;
+
+  T rstd = pow(var + epsilon, T(-0.5));
+  T offset = -rstd * mean;
+
+  VEC4_T v = load_texel(t_in, lpos);
+  VEC4_T weight = load_texel(t_weight, ivec3(lpos.x, 0, 0));
+  VEC4_T bias = load_texel(t_bias, ivec3(lpos.x, 0, 0));
+  VEC4_T outtex = (v * rstd + offset) * weight + bias;
+
+  if (all(lessThan(lpos, out_limits))) {
+    write_texel_lpos(t_out, lpos, outtex, out_axis_map);
+  }
+
+  if (gl_GlobalInvocationID.x == 0) {
     write_texel(t_mean, lpos, VEC4_T(mean));
     write_texel(t_rstd, lpos, VEC4_T(rstd));
   }
 }
+
+void main() {
+  // Dispatch based on whether the width dimension is the packed dimension
+  if (in_packed_dim != W_DIM) {
+    reduce_non_packed_dim();
+  } else {
+    reduce_packed_dim();
+  }
+}