Skip to content

Commit 306dfbb

Browse files
qjia7 authored and fs-eire committed
[webgpu] Add int64 to cast (#25610)
This pull request extends the WebGPU execution provider to support int64 data type casting in the `Cast` operator, with support conditional on whether graph capture is enabled. It refactors kernel registration so that int64 support can be toggled, and updates the shader code and kernel logic to handle int64 tensors efficiently. It is part of the work to enable graph capture in phi4 (#25868).
1 parent 8944bbe commit 306dfbb

File tree

4 files changed

+114
-79
lines changed

4 files changed

+114
-79
lines changed

onnxruntime/core/providers/webgpu/shader_variable.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,7 @@ std::string ShaderVariableHelper::SetByOffsetImpl(std::string_view offset, std::
378378
ORT_THROW("Invalid type");
379379
break;
380380
case onnxruntime::webgpu::ProgramVariableDataType::Int64:
381-
ss << name_ << "[" << offset << "]=vec2<u32>(u32(" << value << "), select(0u, 0xFFFFFFFFu, " << value << " < 0));";
381+
ss << name_ << "[" << offset << "]=vec2<u32>(u32(" << value << "), select(0u, 0xFFFFFFFFu, i32(" << value << ") < 0));";
382382
break;
383383
case onnxruntime::webgpu::ProgramVariableDataType::Uint64:
384384
ss << name_ << "[" << offset << "]=vec2<u32>(u32(" << value << "), 0u);";

onnxruntime/core/providers/webgpu/tensor/cast.cc

Lines changed: 88 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -11,88 +11,47 @@ namespace onnxruntime {
1111
namespace webgpu {
1212

1313
namespace {
14-
const std::vector<MLDataType>& CastOpTypeConstraints() {
15-
// currently support boolean, integer and float types that explicitly allowed in WGSL:
14+
const std::vector<MLDataType>& CastOpTypeConstraints(bool enable_graph_capture) {
15+
// Base types that are always supported - boolean, integer and float types that explicitly allowed in WGSL:
1616
// https://gpuweb.github.io/gpuweb/wgsl/#plain-types-section
17-
//
18-
static std::vector<MLDataType> types{
17+
static std::vector<MLDataType> base_types{
1918
DataTypeImpl::GetTensorType<MLFloat16>(),
2019
DataTypeImpl::GetTensorType<float>(),
2120
DataTypeImpl::GetTensorType<int32_t>(),
2221
DataTypeImpl::GetTensorType<uint32_t>(),
2322
DataTypeImpl::GetTensorType<bool>()};
24-
return types;
23+
24+
if (enable_graph_capture) {
25+
static std::vector<MLDataType> types_with_int64 = []() {
26+
auto types = base_types;
27+
types.push_back(DataTypeImpl::GetTensorType<int64_t>());
28+
return types;
29+
}();
30+
return types_with_int64;
31+
} else {
32+
return base_types;
33+
}
2534
}
2635
} // namespace
2736

28-
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
29-
Cast,
30-
kOnnxDomain,
31-
6, 8,
32-
kWebGpuExecutionProvider,
33-
(*KernelDefBuilder::Create())
34-
.TypeConstraint("T1", CastOpTypeConstraints())
35-
.TypeConstraint("T2", CastOpTypeConstraints()),
36-
Cast);
37-
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
38-
Cast,
39-
kOnnxDomain,
40-
9, 12,
41-
kWebGpuExecutionProvider,
42-
(*KernelDefBuilder::Create())
43-
.TypeConstraint("T1", CastOpTypeConstraints())
44-
.TypeConstraint("T2", CastOpTypeConstraints()),
45-
Cast);
46-
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
47-
Cast,
48-
kOnnxDomain,
49-
13, 18,
50-
kWebGpuExecutionProvider,
51-
(*KernelDefBuilder::Create())
52-
.TypeConstraint("T1", CastOpTypeConstraints())
53-
.TypeConstraint("T2", CastOpTypeConstraints()),
54-
Cast);
55-
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
56-
Cast,
57-
kOnnxDomain,
58-
19, 20,
59-
kWebGpuExecutionProvider,
60-
(*KernelDefBuilder::Create())
61-
.TypeConstraint("T1", CastOpTypeConstraints())
62-
.TypeConstraint("T2", CastOpTypeConstraints()),
63-
Cast);
64-
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
65-
Cast,
66-
kOnnxDomain,
67-
21, 22,
68-
kWebGpuExecutionProvider,
69-
(*KernelDefBuilder::Create())
70-
.TypeConstraint("T1", CastOpTypeConstraints())
71-
.TypeConstraint("T2", CastOpTypeConstraints()),
72-
Cast);
73-
ONNX_OPERATOR_KERNEL_EX(
74-
Cast,
75-
kOnnxDomain,
76-
23,
77-
kWebGpuExecutionProvider,
78-
(*KernelDefBuilder::Create())
79-
.TypeConstraint("T1", CastOpTypeConstraints())
80-
.TypeConstraint("T2", CastOpTypeConstraints()),
81-
Cast);
82-
8337
Status Cast::ComputeInternal(ComputeContext& context) const {
8438
const auto* input_tensor = context.Input(0);
8539
auto* output_tensor = context.Output(0, input_tensor->Shape());
8640
int64_t size = input_tensor->Shape().Size();
8741
if (size == 0) {
8842
return Status::OK();
8943
}
44+
bool is_from_int64 = input_tensor->DataType() == DataTypeImpl::GetType<int64_t>();
45+
const int in_components = is_from_int64 ? 1 : 4;
46+
const int out_components = to_ == ONNX_NAMESPACE::TensorProto_DataType_INT64 ? 1 : 4;
9047
uint32_t vec_size = onnxruntime::narrow<uint32_t>((size + 3) / 4);
48+
uint32_t in_vec_size = onnxruntime::narrow<uint32_t>(in_components == 1 ? size : vec_size);
49+
uint32_t out_vec_size = onnxruntime::narrow<uint32_t>(out_components == 1 ? size : vec_size);
9150

92-
CastProgram program{to_};
51+
CastProgram program{to_, is_from_int64};
9352
program
94-
.AddInput({input_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4})
95-
.AddOutput({output_tensor, ProgramTensorMetadataDependency::None, {vec_size}, 4})
53+
.AddInput({input_tensor, ProgramTensorMetadataDependency::Type, {in_vec_size}, in_components})
54+
.AddOutput({output_tensor, ProgramTensorMetadataDependency::None, {out_vec_size}, out_components})
9655
.SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
9756
.AddUniformVariables({
9857
{static_cast<uint32_t>(vec_size)},
@@ -121,15 +80,78 @@ Status CastProgram::GenerateShaderCode(ShaderHelper& sh) const {
12180
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
12281
expression = "vec4<bool>(a)";
12382
break;
83+
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
84+
expression = "int32(a)";
85+
break;
12486
default:
12587
ORT_NOT_IMPLEMENTED("Cast to type ", to_, " is not supported.");
12688
}
127-
sh.MainFunctionBody() << sh.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size")
128-
<< " let a = " << input.GetByOffset("global_idx") << ";\n "
129-
<< output.SetByOffset("global_idx", expression);
89+
90+
sh.MainFunctionBody() << sh.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size");
91+
if (is_from_int64_) {
92+
sh.MainFunctionBody() << " let a0 = " << input.GetByOffset("global_idx * 4") << ";\n"
93+
<< " let a1 = " << input.GetByOffset("global_idx * 4 + 1") << ";\n"
94+
<< " let a2 = " << input.GetByOffset("global_idx * 4 + 2") << ";\n"
95+
<< " let a3 = " << input.GetByOffset("global_idx * 4 + 3") << ";\n"
96+
<< " let a = vec4<i32>(a0, a1, a2, a3);\n";
97+
} else {
98+
sh.MainFunctionBody() << " let a = " << input.GetByOffset("global_idx") << ";\n";
99+
}
100+
if (to_ == ONNX_NAMESPACE::TensorProto_DataType_INT64) {
101+
sh.MainFunctionBody() << output.SetByOffset("global_idx * 4", "a.x") << "\n"
102+
<< output.SetByOffset("global_idx * 4 + 1", "a.y") << "\n"
103+
<< output.SetByOffset("global_idx * 4 + 2", "a.z") << "\n"
104+
<< output.SetByOffset("global_idx * 4 + 3", "a.w") << "\n";
105+
} else {
106+
sh.MainFunctionBody() << output.SetByOffset("global_idx", expression);
107+
}
130108

131109
return Status::OK();
132110
}
133111

112+
template <int StartVersion, int EndVersion>
113+
KernelCreateInfo CreateCastKernelInfo(bool enable_graph_capture) {
114+
const auto& type_constraints = CastOpTypeConstraints(enable_graph_capture);
115+
116+
KernelCreateFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status {
117+
out = std::make_unique<Cast>(info);
118+
return Status::OK();
119+
};
120+
121+
if constexpr (StartVersion == EndVersion) {
122+
// Non-versioned kernel
123+
return {
124+
KernelDefBuilder()
125+
.SetName("Cast")
126+
.SetDomain(kOnnxDomain)
127+
.SinceVersion(StartVersion)
128+
.Provider(kWebGpuExecutionProvider)
129+
.TypeConstraint("T1", type_constraints)
130+
.TypeConstraint("T2", type_constraints)
131+
.Build(),
132+
kernel_create_fn};
133+
} else {
134+
// Versioned kernel
135+
return {
136+
KernelDefBuilder()
137+
.SetName("Cast")
138+
.SetDomain(kOnnxDomain)
139+
.SinceVersion(StartVersion, EndVersion)
140+
.Provider(kWebGpuExecutionProvider)
141+
.TypeConstraint("T1", type_constraints)
142+
.TypeConstraint("T2", type_constraints)
143+
.Build(),
144+
kernel_create_fn};
145+
}
146+
}
147+
148+
// Explicit template instantiations
149+
template KernelCreateInfo CreateCastKernelInfo<6, 8>(bool);
150+
template KernelCreateInfo CreateCastKernelInfo<9, 12>(bool);
151+
template KernelCreateInfo CreateCastKernelInfo<13, 18>(bool);
152+
template KernelCreateInfo CreateCastKernelInfo<19, 20>(bool);
153+
template KernelCreateInfo CreateCastKernelInfo<21, 22>(bool);
154+
template KernelCreateInfo CreateCastKernelInfo<23>(bool);
155+
134156
} // namespace webgpu
135157
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/tensor/cast.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,24 @@
33

44
#pragma once
55

6+
#include "core/framework/kernel_registry.h"
7+
#include "core/framework/op_kernel.h"
68
#include "core/providers/webgpu/webgpu_kernel.h"
79

810
namespace onnxruntime {
911
namespace webgpu {
1012

1113
class CastProgram final : public Program<CastProgram> {
1214
public:
13-
CastProgram(int32_t to) : Program{"Cast"}, to_{to} {}
15+
CastProgram(int32_t to, bool is_from_int64) : Program{"Cast"}, to_{to}, is_from_int64_{is_from_int64} {}
1416

1517
Status GenerateShaderCode(ShaderHelper& sh) const override;
1618

1719
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32});
1820

1921
private:
2022
int32_t to_;
23+
bool is_from_int64_;
2124
};
2225

2326
class Cast final : public WebGpuKernel {
@@ -37,5 +40,9 @@ class Cast final : public WebGpuKernel {
3740
int32_t to_;
3841
};
3942

43+
// Create Cast kernel info with appropriate type constraints based on graph capture support
44+
template <int StartVersion, int EndVersion = StartVersion>
45+
KernelCreateInfo CreateCastKernelInfo(bool enable_graph_capture);
46+
4047
} // namespace webgpu
4148
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "core/providers/webgpu/data_transfer.h"
2929
#include "core/providers/webgpu/external_data_loader.h"
3030
#include "core/providers/webgpu/webgpu_profiler.h"
31+
#include "core/providers/webgpu/tensor/cast.h"
3132

3233
namespace onnxruntime {
3334

@@ -417,7 +418,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxD
417418
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, 17, ScatterND);
418419
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ScatterND);
419420

420-
std::unique_ptr<KernelRegistry> RegisterKernels() {
421+
std::unique_ptr<KernelRegistry> RegisterKernels(bool enable_graph_capture = false) {
421422
auto kernel_registry = std::make_unique<onnxruntime::KernelRegistry>();
422423

423424
static const BuildKernelCreateInfoFn function_table[] = {
@@ -464,13 +465,6 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
464465
KERNEL_CREATE_INFO(13, Tanh),
465466
KERNEL_CREATE_INFO(1, Not),
466467

467-
KERNEL_CREATE_INFO_VERSIONED(6, 8, Cast),
468-
KERNEL_CREATE_INFO_VERSIONED(9, 12, Cast),
469-
KERNEL_CREATE_INFO_VERSIONED(13, 18, Cast),
470-
KERNEL_CREATE_INFO_VERSIONED(19, 20, Cast),
471-
KERNEL_CREATE_INFO_VERSIONED(21, 22, Cast),
472-
KERNEL_CREATE_INFO(23, Cast),
473-
474468
// // activations
475469
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, float, Clip)>,
476470
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, float, Clip)>,
@@ -771,6 +765,14 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
771765
}
772766
}
773767

768+
// Register Cast kernels with conditional int64 support based on graph capture
769+
ORT_THROW_IF_ERROR(kernel_registry->Register(CreateCastKernelInfo<6, 8>(enable_graph_capture)));
770+
ORT_THROW_IF_ERROR(kernel_registry->Register(CreateCastKernelInfo<9, 12>(enable_graph_capture)));
771+
ORT_THROW_IF_ERROR(kernel_registry->Register(CreateCastKernelInfo<13, 18>(enable_graph_capture)));
772+
ORT_THROW_IF_ERROR(kernel_registry->Register(CreateCastKernelInfo<19, 20>(enable_graph_capture)));
773+
ORT_THROW_IF_ERROR(kernel_registry->Register(CreateCastKernelInfo<21, 22>(enable_graph_capture)));
774+
ORT_THROW_IF_ERROR(kernel_registry->Register(CreateCastKernelInfo<23>(enable_graph_capture)));
775+
774776
#ifndef DISABLE_CONTRIB_OPS
775777
Status status = ::onnxruntime::contrib::webgpu::RegisterWebGpuContribKernels(*kernel_registry);
776778
ORT_ENFORCE(status.IsOK(), "Failed to register WebGPU contrib kernels: " + status.ErrorMessage());
@@ -905,9 +907,13 @@ std::vector<std::unique_ptr<ComputeCapability>> WebGpuExecutionProvider::GetCapa
905907
}
906908

907909
std::shared_ptr<KernelRegistry> WebGpuExecutionProvider::GetKernelRegistry() const {
908-
static std::shared_ptr<KernelRegistry> registry = webgpu::RegisterKernels();
909-
910-
return registry;
910+
if (enable_graph_capture_) {
911+
static std::shared_ptr<KernelRegistry> registry = webgpu::RegisterKernels(true);
912+
return registry;
913+
} else {
914+
static std::shared_ptr<KernelRegistry> registry = webgpu::RegisterKernels(false);
915+
return registry;
916+
}
911917
}
912918

913919
std::unique_ptr<onnxruntime::IDataTransfer> WebGpuExecutionProvider::GetDataTransfer() const {

0 commit comments

Comments
 (0)