diff --git a/onnxruntime/core/providers/webgpu/nn/instance_norm.cc b/onnxruntime/core/providers/webgpu/nn/instance_norm.cc index 0cab454a5a530..f3bccec4872fc 100644 --- a/onnxruntime/core/providers/webgpu/nn/instance_norm.cc +++ b/onnxruntime/core/providers/webgpu/nn/instance_norm.cc @@ -13,23 +13,25 @@ namespace onnxruntime { namespace webgpu { Status ComputeChannelScaleShiftProgram::GenerateShaderCode(ShaderHelper& shader) const { - const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseIndicesTypeAlias); + const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); const auto& scale = shader.AddInput("scale", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); const auto& bias = shader.AddInput("bias", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); - const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - shader.AdditionalImplementation() << "var workgroup_shared_sum : array;\n" - << "var workgroup_shared_squared_sum : array;\n" + shader.AdditionalImplementation() << "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? "vec2" : "f32")) << ";\n" + << "var workgroup_shared_sum : array;\n" + << "var workgroup_shared_squared_sum : array;\n" << "const workgroup_size = " << workgroup_size_ << ";\n"; + shader.MainFunctionBody() << " let batch = workgroup_idx / uniforms.x_shape[1];\n" << " let channel = workgroup_idx % uniforms.x_shape[1];\n" << " let hight = uniforms.x_shape[2];\n" << " // initialize workgroup memory<< \n" - << " var sum = x_value_t(0);\n" - << " var squared_sum = x_value_t(0);\n" + << " var sum = f32_val_t(0);\n" + << " var squared_sum = f32_val_t(0);\n" << " for (var h = local_idx; h < hight; h += workgroup_size) {\n" << " let indices = x_indices_t(batch, channel, h);\n" - << " let value =" << input.GetByIndices("indices") << ";\n" + << " let value = f32_val_t(" << input.GetByIndices("indices") << ");\n" << " sum += value;\n" << " squared_sum += value * value;\n" << " }\n" @@ -44,12 +46,12 @@ Status ComputeChannelScaleShiftProgram::GenerateShaderCode(ShaderHelper& shader) << " workgroupBarrier();\n" << " }\n" << " if (local_idx == 0) {\n" - << " let sum_final = " << SumVector("workgroup_shared_sum[0]", components_) << " / x_element_t(hight * " << components_ << ");\n" - << " let squared_sum_final = " << SumVector("workgroup_shared_squared_sum[0]", components_) << " / x_element_t(hight * " << components_ << ");\n" - << " let inv_std_dev = inverseSqrt(squared_sum_final - sum_final * sum_final + x_element_t(" << std::to_string(epsilon_) << "));\n" - << " let channel_scale = inv_std_dev * " << scale.GetByOffset("channel") << ";\n" - << " let channel_shift = " << bias.GetByOffset("channel") << " - sum_final * channel_scale;\n" - << " " << output.SetByOffset("workgroup_idx", "output_value_t(channel_scale, channel_shift)") << ";\n" + << " let sum_final = " << SumVector("workgroup_shared_sum[0]", components_) << " / f32(hight * " << components_ << ");\n" + << " let squared_sum_final = " << SumVector("workgroup_shared_squared_sum[0]", components_) << " / f32(hight * " << components_ << ");\n" + << " let inv_std_dev = inverseSqrt(squared_sum_final - sum_final * sum_final + f32(" << std::to_string(epsilon_) << "));\n" + << " let channel_scale = inv_std_dev * f32(" << scale.GetByOffset("channel") << ");\n" + << " let channel_shift = f32(" << bias.GetByOffset("channel") << ") - sum_final * channel_scale;\n" + << " " << output.SetByOffset("workgroup_idx", "output_value_t(output_element_t(channel_scale), output_element_t(channel_shift))") << ";\n" << " }\n"; return Status::OK(); } @@ -110,7 +112,7 @@ Status InstanceNormProgramNHWC::GenerateShaderCode(ShaderHelper& shader) const { << "let input_value = " << input.GetByOffset("global_idx") << ";\n"; if (components_ > 1) { shader.MainFunctionBody() << "for (var i : u32 = 0; i < uniforms.components; i = i + 1) {\n" - << " let scale_sift = " << channel_scale_shift.GetByOffset("scale_offset + i") << ";\n" + << " let scale_sift = " << channel_scale_shift.GetByOffset("uniforms.components * scale_offset + i") << ";\n" << " scale[i] = input_element_t(scale_sift.x);\n" << " shift[i] = input_element_t(scale_sift.y);\n" << "}\n";