Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 16 additions & 14 deletions onnxruntime/core/providers/webgpu/nn/instance_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,25 @@ namespace onnxruntime {
namespace webgpu {

Status ComputeChannelScaleShiftProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseIndicesTypeAlias);
const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
const auto& scale = shader.AddInput("scale", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
const auto& bias = shader.AddInput("bias", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);

shader.AdditionalImplementation() << "var<workgroup> workgroup_shared_sum : array<x_value_t, " << workgroup_size_ << ">;\n"
<< "var<workgroup> workgroup_shared_squared_sum : array<x_value_t, " << workgroup_size_ << ">;\n"
shader.AdditionalImplementation() << "alias f32_val_t = " << (components_ == 4 ? "vec4<f32>" : (components_ == 2 ? "vec2<f32>" : "f32")) << ";\n"
<< "var<workgroup> workgroup_shared_sum : array<f32_val_t, " << workgroup_size_ << ">;\n"
<< "var<workgroup> workgroup_shared_squared_sum : array<f32_val_t, " << workgroup_size_ << ">;\n"
<< "const workgroup_size = " << workgroup_size_ << ";\n";

shader.MainFunctionBody() << " let batch = workgroup_idx / uniforms.x_shape[1];\n"
<< " let channel = workgroup_idx % uniforms.x_shape[1];\n"
<< " let hight = uniforms.x_shape[2];\n"
<< " // initialize workgroup memory<< \n"
<< " var sum = x_value_t(0);\n"
<< " var squared_sum = x_value_t(0);\n"
<< " var sum = f32_val_t(0);\n"
<< " var squared_sum = f32_val_t(0);\n"
<< " for (var h = local_idx; h < hight; h += workgroup_size) {\n"
<< " let indices = x_indices_t(batch, channel, h);\n"
<< " let value =" << input.GetByIndices("indices") << ";\n"
<< " let value = f32_val_t(" << input.GetByIndices("indices") << ");\n"
<< " sum += value;\n"
<< " squared_sum += value * value;\n"
<< " }\n"
Expand All @@ -44,12 +46,12 @@ Status ComputeChannelScaleShiftProgram::GenerateShaderCode(ShaderHelper& shader)
<< " workgroupBarrier();\n"
<< " }\n"
<< " if (local_idx == 0) {\n"
<< " let sum_final = " << SumVector("workgroup_shared_sum[0]", components_) << " / x_element_t(hight * " << components_ << ");\n"
<< " let squared_sum_final = " << SumVector("workgroup_shared_squared_sum[0]", components_) << " / x_element_t(hight * " << components_ << ");\n"
<< " let inv_std_dev = inverseSqrt(squared_sum_final - sum_final * sum_final + x_element_t(" << std::to_string(epsilon_) << "));\n"
<< " let channel_scale = inv_std_dev * " << scale.GetByOffset("channel") << ";\n"
<< " let channel_shift = " << bias.GetByOffset("channel") << " - sum_final * channel_scale;\n"
<< " " << output.SetByOffset("workgroup_idx", "output_value_t(channel_scale, channel_shift)") << ";\n"
<< " let sum_final = " << SumVector("workgroup_shared_sum[0]", components_) << " / f32(hight * " << components_ << ");\n"
<< " let squared_sum_final = " << SumVector("workgroup_shared_squared_sum[0]", components_) << " / f32(hight * " << components_ << ");\n"
<< " let inv_std_dev = inverseSqrt(squared_sum_final - sum_final * sum_final + f32(" << std::to_string(epsilon_) << "));\n"
<< " let channel_scale = inv_std_dev * f32(" << scale.GetByOffset("channel") << ");\n"
<< " let channel_shift = f32(" << bias.GetByOffset("channel") << ") - sum_final * channel_scale;\n"
<< " " << output.SetByOffset("workgroup_idx", "output_value_t(output_element_t(channel_scale), output_element_t(channel_shift))") << ";\n"
<< " }\n";
return Status::OK();
}
Expand Down Expand Up @@ -110,7 +112,7 @@ Status InstanceNormProgramNHWC::GenerateShaderCode(ShaderHelper& shader) const {
<< "let input_value = " << input.GetByOffset("global_idx") << ";\n";
if (components_ > 1) {
shader.MainFunctionBody() << "for (var i : u32 = 0; i < uniforms.components; i = i + 1) {\n"
<< " let scale_sift = " << channel_scale_shift.GetByOffset("scale_offset + i") << ";\n"
<< " let scale_sift = " << channel_scale_shift.GetByOffset("uniforms.components * scale_offset + i") << ";\n"
<< " scale[i] = input_element_t(scale_sift.x);\n"
<< " shift[i] = input_element_t(scale_sift.y);\n"
<< "}\n";
Expand Down
Loading