Skip to content

Commit ef6fbce

Browse files
committed
[ET-VK] Minor performance improvements to native layer norm.
This diff introduces minor performance improvements to the native layer norm function in the Vulkan backend of Executorch. In this new approach: The mean and variance values are calculated in 2 separate passes. Shader is dispatched based on input texture size, and input texel is read and stored in shared memory. Input stored in shard memory is then summed up using a reduce function. This implementation better utilizes a GPUs parallel processing capabilities. Differential Revision: [D72430290](https://our.internmc.facebook.com/intern/diff/D72430290/) ghstack-source-id: 276053981 Pull Request resolved: #9892
1 parent e9c2315 commit ef6fbce

File tree

2 files changed

+161
-77
lines changed

2 files changed

+161
-77
lines changed

backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl

Lines changed: 159 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -43,106 +43,190 @@ ${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
4343
const lowp ivec4 out_axis_map = unhash_axis_map(out_layout);
4444
const lowp int out_packed_dim = unhash_packed_dim(out_layout);
4545

46-
void main() {
47-
const ivec3 lpos = ivec3(gl_GlobalInvocationID);
46+
#define SHARED_MEMORY_FACTOR 2
47+
#define MAX_WORKGROUP_SIZE 64
48+
49+
#define offset_pos_index(index) ((index) + ((index) >> 2))
50+
51+
shared VEC4_T shared_input[offset_pos_index(MAX_WORKGROUP_SIZE * SHARED_MEMORY_FACTOR)];
52+
53+
// function to reduce input data in workgroup's x dimension
54+
void reduce_input(const int width_stride, const int shared_idx_offset) {
55+
// wait for all shared memory writes to finish
56+
memoryBarrierShared();
57+
barrier();
58+
59+
// loop log(width_stride) times
60+
for (int current_stride = 1, index = int(gl_LocalInvocationID.x << 1); current_stride < width_stride; current_stride *= 2, index <<= 1) {
61+
// if the index at this thread is within the width stride
62+
if (index < width_stride) {
63+
const int local_shared_idx = shared_idx_offset + index;
64+
// add the value at current stride to this thread's value
65+
shared_input[offset_pos_index(local_shared_idx)] += shared_input[offset_pos_index(local_shared_idx + current_stride)];
66+
}
4867

49-
if (any(greaterThanEqual(lpos, out_limits))) {
50-
return;
68+
memoryBarrierShared();
69+
barrier();
5170
}
71+
}
5272

73+
void main() {
74+
const ivec3 lpos = ivec3(gl_GlobalInvocationID);
5375
const int width = int(sizes.x);
5476

77+
ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
78+
79+
// width batch read stride
80+
const int width_stride = int(gl_WorkGroupSize.x) * SHARED_MEMORY_FACTOR;
81+
82+
// local memory starting offset for this thread
83+
const int shared_idx_offset = width_stride * int(gl_WorkGroupSize.y * gl_LocalInvocationID.z + gl_LocalInvocationID.y);
84+
85+
// local memory index for this thread
86+
const int shared_idx = shared_idx_offset + int(gl_LocalInvocationID.x);
87+
88+
// if packed dimension width
5589
if (in_packed_dim != W_DIM) {
5690
VEC4_T mean = VEC4_T(0);
57-
VEC4_T delta = VEC4_T(0);
58-
VEC4_T delta2 = VEC4_T(0);
59-
VEC4_T M2 = VEC4_T(0);
60-
61-
// Use Welford's online algorithm to compute mean and variance in one pass
62-
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
63-
ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
64-
for (int w = 0; w < width; ++w) {
65-
in_pos[in_axis_map.x] = w;
66-
VEC4_T v = load_texel(t_in, in_pos);
67-
delta = v - mean;
68-
mean += delta / (w + 1);
69-
delta2 = v - mean;
70-
M2 += delta * delta2;
91+
VEC4_T var = VEC4_T(0);
92+
93+
// Loop over the width in stride increments
94+
for (int width_offset = 0; width_offset < width; width_offset += width_stride) {
95+
// Read input in shared memory
96+
for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
97+
in_pos[in_axis_map.x] = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
98+
99+
VEC4_T in_val = VEC4_T(0);
100+
if (all(lessThan(in_pos, out_limits))) {
101+
in_val = load_texel(t_in, in_pos);
102+
}
103+
shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
104+
}
105+
106+
reduce_input(width_stride, shared_idx_offset);
107+
mean += shared_input[offset_pos_index(shared_idx_offset)];
108+
}
109+
110+
mean /= width;
111+
112+
// Loop over the width in stride increments
113+
for (int width_offset = 0; width_offset < width; width_offset += width_stride) {
114+
// Read input in shared memory
115+
for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
116+
in_pos[in_axis_map.x] = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
117+
118+
VEC4_T in_val = mean;
119+
if (all(lessThan(in_pos, out_limits))) {
120+
in_val = load_texel(t_in, in_pos);
121+
}
122+
123+
const VEC4_T delta = in_val - mean;
124+
shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta * delta;
125+
}
126+
127+
reduce_input(width_stride, shared_idx_offset);
128+
var += shared_input[offset_pos_index(shared_idx_offset)];
71129
}
72130

73-
VEC4_T var = M2 / width;
131+
var /= width;
132+
74133
VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5));
75134
VEC4_T offset = -rstd * mean;
76135

77-
for (int w = 0; w < width; ++w) {
78-
in_pos[in_axis_map.x] = w;
79-
VEC4_T v = load_texel(t_in, in_pos);
80-
// broadcasting
81-
VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0)).xxxx;
82-
VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0)).xxxx;
83-
VEC4_T outtex = (v * rstd + offset) * weight + bias;
84-
write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map);
136+
VEC4_T v = load_texel(t_in, lpos);
137+
VEC4_T weight = load_texel(t_weight, ivec3(lpos.x, 0, 0)).xxxx;
138+
VEC4_T bias = load_texel(t_bias, ivec3(lpos.x, 0, 0)).xxxx;
139+
VEC4_T outtex = (v * rstd + offset) * weight + bias;
140+
if (all(lessThan(lpos, out_limits))) {
141+
write_texel_lpos(t_out, ivec3(lpos.x, lpos.y, lpos.z), outtex, out_axis_map);
85142
}
86143

87-
write_texel(t_mean, lpos, mean);
88-
write_texel(t_rstd, lpos, rstd);
144+
if (gl_GlobalInvocationID.x == 0) {
145+
write_texel(t_mean, lpos, mean);
146+
write_texel(t_rstd, lpos, rstd);
147+
}
89148
} else {
90-
const int packed_width = divup4(width);
91-
149+
const int last_packed_width_index = divup4(width) - 1;
92150
T mean = T(0);
93-
T delta = T(0);
94-
T delta2 = T(0);
95-
T M2 = T(0);
96-
// Use Welford's online algorithm to compute mean and variance in one pass
97-
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
98-
ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
99-
T width_counter = T(1);
100-
101-
const bool has_unaligned_width = (width & 0x3) != 0;
102-
const int fully_packed_4_comp_count = packed_width - mix(0, 1, has_unaligned_width);
103-
104-
// iterate through texels that are fully packed ie. has 4 components
105-
for (int w = 0; w < fully_packed_4_comp_count; ++w) {
106-
in_pos[in_axis_map.x] = w;
107-
VEC4_T v = load_texel(t_in, in_pos);
108-
for (int i=0; i<4; i++) {
109-
delta = v[i] - mean;
110-
mean += delta / width_counter;
111-
delta2 = v[i] - mean;
112-
M2 += delta * delta2;
113-
width_counter++;
151+
T var = T(0);
152+
const int remain = width & 3;
153+
154+
const int in_pos_x_limit = out_limits[in_axis_map.x];
155+
156+
// Loop over the width in stride increments
157+
for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
158+
// Read input in shared memory
159+
for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
160+
const int in_pos_x = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
161+
in_pos[in_axis_map.x] = in_pos_x;
162+
163+
VEC4_T in_val = VEC4_T(0);
164+
if (in_pos_x < in_pos_x_limit) {
165+
in_val = load_texel(t_in, in_pos);
166+
}
167+
168+
if (in_pos_x == last_packed_width_index && remain != 0) {
169+
const int remain_inv = 4 - remain;
170+
in_val.y = mix(in_val.y, T(0), remain_inv > 2);
171+
in_val.z = mix(in_val.z, T(0), remain_inv > 1);
172+
in_val.w = mix(in_val.w, T(0), remain_inv > 0);
173+
}
174+
175+
shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
114176
}
177+
178+
reduce_input(width_stride, shared_idx_offset);
179+
const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
180+
mean += val.x + val.y + val.z + val.w;
115181
}
116182

117-
// handle last texel if its not 4 aligned
118-
if (has_unaligned_width) {
119-
in_pos[in_axis_map.x] = fully_packed_4_comp_count;
120-
const int remaining_width = width & 0x3;
121-
122-
VEC4_T v = load_texel(t_in, in_pos);
123-
for (int i=0; i<remaining_width; i++) {
124-
delta = v[i] - mean;
125-
mean += delta / width_counter;
126-
delta2 = v[i] - mean;
127-
M2 += delta * delta2;
128-
width_counter++;
183+
mean /= width;
184+
185+
// Loop over the width in stride increments
186+
for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
187+
// Read input in shared memory
188+
for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
189+
const int in_pos_x = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
190+
in_pos[in_axis_map.x] = in_pos_x;
191+
192+
VEC4_T in_val = VEC4_T(mean);
193+
if (in_pos_x < in_pos_x_limit) {
194+
in_val = load_texel(t_in, in_pos);
195+
}
196+
197+
if (in_pos_x == last_packed_width_index && remain != 0) {
198+
const int remain_inv = 4 - remain;
199+
in_val.y = mix(in_val.y, mean.x, remain_inv > 2);
200+
in_val.z = mix(in_val.z, mean.x, remain_inv > 1);
201+
in_val.w = mix(in_val.w, mean.x, remain_inv > 0);
202+
}
203+
204+
const VEC4_T delta = in_val - mean;
205+
const VEC4_T delta2 = delta * delta;
206+
shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta2;
129207
}
208+
209+
reduce_input(width_stride, shared_idx_offset);
210+
const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
211+
var += val.x + val.y + val.z + val.w;
130212
}
131213

132-
T var = M2 / (width_counter - 1);
133-
T rstd = inversesqrt(var + epsilon);
214+
var /= width;
215+
216+
T rstd = pow(var + epsilon, T(-0.5));
134217
T offset = -rstd * mean;
135218

136-
for (int w = 0; w < packed_width; ++w) {
137-
in_pos[in_axis_map.x] = w;
138-
VEC4_T v = load_texel(t_in, in_pos);
139-
VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0));
140-
VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0));
141-
VEC4_T outtex = (v * rstd + offset) * weight + bias;
142-
write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map);
219+
VEC4_T v = load_texel(t_in, lpos);
220+
VEC4_T weight = load_texel(t_weight, ivec3(lpos.x, 0, 0));
221+
VEC4_T bias = load_texel(t_bias, ivec3(lpos.x, 0, 0));
222+
VEC4_T outtex = (v * rstd + offset) * weight + bias;
223+
if (all(lessThan(lpos, out_limits))) {
224+
write_texel_lpos(t_out, ivec3(lpos.x, lpos.y, lpos.z), outtex, out_axis_map);
143225
}
144226

145-
write_texel(t_mean, lpos, VEC4_T(mean));
146-
write_texel(t_rstd, lpos, VEC4_T(rstd));
227+
if (gl_GlobalInvocationID.x == 0) {
228+
write_texel(t_mean, lpos, VEC4_T(mean));
229+
write_texel(t_rstd, lpos, VEC4_T(rstd));
230+
}
147231
}
148232
}

backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ void add_native_layer_norm_node(
8383

8484
std::vector<int64_t> in_sizes = t_input->sizes();
8585

86-
utils::uvec3 global_size = t_mean->logical_limits();
87-
utils::uvec3 local_size = adaptive_work_group_size(global_size);
86+
utils::uvec3 global_size = t_out->logical_limits();
87+
utils::uvec3 local_size = graph.create_local_wg_size(global_size);
8888

8989
std::string kernel_name("native_layer_norm");
9090
kernel_name.reserve(kShaderNameReserve);

0 commit comments

Comments
 (0)