diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 26f461c062f..b33430a6bca 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -276,6 +276,7 @@ def register_binary_op(features: OpFeatures): exir_ops.edge.aten.sqrt.default, exir_ops.edge.aten.rsqrt.default, exir_ops.edge.aten.tanh.default, + exir_ops.edge.aten.round.default, ] ) def register_unary_op(features: OpFeatures): @@ -576,7 +577,6 @@ def register_ported_op_all_packed_dims(features: OpFeatures): [ exir_ops.edge.aten.embedding.default, exir_ops.edge.aten._native_batch_norm_legit_no_training.default, - exir_ops.edge.aten.native_layer_norm.default, ] ) def register_ported_ops_with_prepacking(features: OpFeatures): @@ -587,6 +587,20 @@ def register_ported_ops_with_prepacking(features: OpFeatures): return features +# Ported ops that support their own prepacking and all packed dims. +@update_features( + [ + exir_ops.edge.aten.native_layer_norm.default, + ] +) +def register_ported_ops_with_prepacking_all_dims(features: OpFeatures): + features.texture_impl = TextureImplFeatures( + valid_packed_dims=all_packed_dims, + ) + features.handles_own_prepacking = True + return features + + ####################### ## Utility functions ## ####################### diff --git a/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl b/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl index f984821600b..d6c94661ace 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl @@ -15,6 +15,8 @@ #define VEC4_T ${texel_type(DTYPE)} +#define T ${texel_component_type(DTYPE)} + layout(std430) buffer; ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} @@ -25,9 +27,11 @@ ${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)} ${layout_declare_tensor(B, "r", "t_weight", DTYPE, STORAGE)} ${layout_declare_tensor(B, "r", "t_bias", DTYPE, STORAGE)} -${layout_declare_ubo(B, "ivec3", 
"out_limits")} -${layout_declare_ubo(B, "ivec4", "sizes")} -${layout_declare_ubo(B, "float", "epsilon")} +layout(push_constant) uniform PRECISION restrict Block { + ivec3 out_limits; + ivec4 sizes; + float epsilon; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; @@ -48,37 +52,97 @@ void main() { const int width = int(sizes.x); - VEC4_T mean = VEC4_T(0); - VEC4_T delta = VEC4_T(0); - VEC4_T delta2 = VEC4_T(0); - VEC4_T M2 = VEC4_T(0); - - // Use Welford's online algorithm to compute mean and variance in one pass - // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm - ivec3 in_pos = lpos_to_pos(lpos, in_axis_map); - for (int w = 0; w < width; ++w) { - in_pos[in_axis_map.x] = w; - VEC4_T v = load_texel(t_in, in_pos); - delta = v - mean; - mean += delta / (w + 1); - delta2 = v - mean; - M2 += delta * delta2; - } - - VEC4_T var = M2 / width; - VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5)); - VEC4_T offset = -rstd * mean; - - for (int w = 0; w < width; ++w) { - in_pos[in_axis_map.x] = w; - VEC4_T v = load_texel(t_in, in_pos); - // broadcasting - VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0)).xxxx; - VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0)).xxxx; - VEC4_T outtex = (v * rstd + offset) * weight + bias; - write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map); + if (in_packed_dim != W_DIM) { + VEC4_T mean = VEC4_T(0); + VEC4_T delta = VEC4_T(0); + VEC4_T delta2 = VEC4_T(0); + VEC4_T M2 = VEC4_T(0); + + // Use Welford's online algorithm to compute mean and variance in one pass + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm + ivec3 in_pos = lpos_to_pos(lpos, in_axis_map); + for (int w = 0; w < width; ++w) { + in_pos[in_axis_map.x] = w; + VEC4_T v = load_texel(t_in, in_pos); + delta = v - mean; + mean += delta / (w + 1); + delta2 = v - mean; + M2 += delta * delta2; + } + + VEC4_T var = M2 / width; + VEC4_T rstd = pow(var + 
epsilon, VEC4_T(-0.5)); + VEC4_T offset = -rstd * mean; + + for (int w = 0; w < width; ++w) { + in_pos[in_axis_map.x] = w; + VEC4_T v = load_texel(t_in, in_pos); + // broadcasting + VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0)).xxxx; + VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0)).xxxx; + VEC4_T outtex = (v * rstd + offset) * weight + bias; + write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map); + } + + write_texel(t_mean, lpos, mean); + write_texel(t_rstd, lpos, rstd); + } else { + const int packed_width = divup4(width); + + T mean = T(0); + T delta = T(0); + T delta2 = T(0); + T M2 = T(0); + // Use Welford's online algorithm to compute mean and variance in one pass + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm + ivec3 in_pos = lpos_to_pos(lpos, in_axis_map); + T width_counter = T(1); + + const bool has_unaligned_width = (width & 0x3) != 0; + const int fully_packed_4_comp_count = packed_width - mix(0, 1, has_unaligned_width); + + // iterate through texels that are fully packed ie. 
has 4 components + for (int w = 0; w < fully_packed_4_comp_count; ++w) { + in_pos[in_axis_map.x] = w; + VEC4_T v = load_texel(t_in, in_pos); + for (int i=0; i<4; i++) { + delta = v[i] - mean; + mean += delta / width_counter; + delta2 = v[i] - mean; + M2 += delta * delta2; + width_counter++; + } + } + + // handle last texel if it's not 4 aligned + if (has_unaligned_width) { + in_pos[in_axis_map.x] = fully_packed_4_comp_count; + const int remaining_width = width & 0x3; + + VEC4_T v = load_texel(t_in, in_pos); + for (int i=0; ivirtual_resize(mean_size); } -void check_args(const api::vTensor& in, const api::vTensor& out) { - VK_CHECK_COND(check_packed_dim_is(in, WHCN::kChannelsDim)); - VK_CHECK_COND(check_packed_dim_is(out, WHCN::kChannelsDim)); -} - void add_native_layer_norm_node( ComputeGraph& graph, const ValueRef in, @@ -84,7 +79,7 @@ void add_native_layer_norm_node( vTensorPtr t_input = graph.get_tensor(in); float epsilon = graph.extract_scalar(eps); - check_args(*t_input, *t_out); + VK_CHECK_COND(check_same_packed_dim(*t_input, *t_out)); std::vector in_sizes = t_input->sizes(); @@ -106,11 +101,7 @@ void add_native_layer_norm_node( vkapi::MemoryAccessType::WRITE}, {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}}, // Shader params buffers - { - t_out->logical_limits_ubo(), - t_out->sizes_ubo(), - graph.create_params_buffer(epsilon), - }, + {}, // Specialization Constants { t_input->hashed_layout(), @@ -118,7 +109,12 @@ void add_native_layer_norm_node( }, // Resizing Logic resize_native_layer_norm_node, - {normalized_shape})); + {normalized_shape}, + { + graph.logical_limits_pc_of(out_val->at(0)), + graph.sizes_pc_of(out_val->at(0)), + PushConstantDataInfo(&epsilon, sizeof(epsilon)), + })); } void native_layer_norm(ComputeGraph& graph, const std::vector& args) { diff --git a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp index 4bf73fad5a1..9a3ab002403 100644 --- 
a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp @@ -149,6 +149,7 @@ DEFINE_HARDSHRINK_FN(hardshrink); DEFINE_ACTIVATION_FN(hardswish); DEFINE_ACTIVATION_FN(hardsigmoid); DEFINE_LEAKY_RELU_FN(leaky_relu); +DEFINE_ACTIVATION_FN(round); REGISTER_OPERATORS { VK_REGISTER_OP(aten.abs.default, abs); @@ -168,6 +169,7 @@ REGISTER_OPERATORS { VK_REGISTER_OP(aten.hardswish.default, hardswish); VK_REGISTER_OP(aten.hardsigmoid.default, hardsigmoid); VK_REGISTER_OP(aten.leaky_relu.default, leaky_relu); + VK_REGISTER_OP(aten.round.default, round); } } // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index d2e09404ca0..85008a52ff0 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -385,6 +385,11 @@ def get_native_layer_norm_inputs(): ((S, XL, M1, M2), [M2], (M2), (M2), 0.001), ] ) + test_suite.layouts = [ + "utils::kWidthPacked", + "utils::kHeightPacked", + "utils::kChannelsPacked", + ] return test_suite @@ -1087,6 +1092,7 @@ def get_reduce_op_inputs(): "aten.hardswish.default", "aten.hardsigmoid.default", "aten.leaky_relu.default", + "aten.round.default", ] ) def get_unary_ops_inputs():