@@ -11,88 +11,47 @@ namespace onnxruntime {
 namespace webgpu {
 
 namespace {
-const std::vector<MLDataType>& CastOpTypeConstraints() {
-  // currently support boolean, integer and float types that explicitly allowed in WGSL:
+const std::vector<MLDataType>& CastOpTypeConstraints(bool enable_graph_capture) {
+  // Base types that are always supported: boolean, integer and float types that are explicitly allowed in WGSL:
   // https://gpuweb.github.io/gpuweb/wgsl/#plain-types-section
-  //
-  static std::vector<MLDataType> types{
+  static std::vector<MLDataType> base_types{
       DataTypeImpl::GetTensorType<MLFloat16>(),
       DataTypeImpl::GetTensorType<float>(),
      DataTypeImpl::GetTensorType<int32_t>(),
       DataTypeImpl::GetTensorType<uint32_t>(),
       DataTypeImpl::GetTensorType<bool>()};
-  return types;
+
+  if (enable_graph_capture) {
+    static std::vector<MLDataType> types_with_int64 = []() {
+      auto types = base_types;
+      types.push_back(DataTypeImpl::GetTensorType<int64_t>());
+      return types;
+    }();
+    return types_with_int64;
+  } else {
+    return base_types;
+  }
 }
 }  // namespace
 
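The added branch extends a function-local static list lazily: the int64-enabled vector is built exactly once by an immediately invoked lambda, so every later call returns a reference to the same object. A minimal standalone sketch of the same pattern, with `std::string` standing in for `MLDataType` and an illustrative `SupportedTypes` name that is not part of this diff:

```cpp
#include <cstdio>
#include <string>
#include <vector>

// Returns a reference to one of two lists, each built exactly once.
const std::vector<std::string>& SupportedTypes(bool enable_extra) {
  static const std::vector<std::string> base{"float16", "float", "int32", "uint32", "bool"};
  if (enable_extra) {
    // The extended list is initialized by an immediately invoked lambda, so the
    // copy and push_back run only on the first call (thread-safe static init).
    static const std::vector<std::string> extended = []() {
      auto types = base;            // copy the base list
      types.emplace_back("int64");  // append the extra type
      return types;
    }();
    return extended;
  }
  return base;
}

int main() {
  std::printf("base: %zu types, extended: %zu types\n",
              SupportedTypes(false).size(), SupportedTypes(true).size());
  return 0;
}
```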
-ONNX_OPERATOR_VERSIONED_KERNEL_EX(
-    Cast,
-    kOnnxDomain,
-    6, 8,
-    kWebGpuExecutionProvider,
-    (*KernelDefBuilder::Create())
-        .TypeConstraint("T1", CastOpTypeConstraints())
-        .TypeConstraint("T2", CastOpTypeConstraints()),
-    Cast);
-ONNX_OPERATOR_VERSIONED_KERNEL_EX(
-    Cast,
-    kOnnxDomain,
-    9, 12,
-    kWebGpuExecutionProvider,
-    (*KernelDefBuilder::Create())
-        .TypeConstraint("T1", CastOpTypeConstraints())
-        .TypeConstraint("T2", CastOpTypeConstraints()),
-    Cast);
-ONNX_OPERATOR_VERSIONED_KERNEL_EX(
-    Cast,
-    kOnnxDomain,
-    13, 18,
-    kWebGpuExecutionProvider,
-    (*KernelDefBuilder::Create())
-        .TypeConstraint("T1", CastOpTypeConstraints())
-        .TypeConstraint("T2", CastOpTypeConstraints()),
-    Cast);
-ONNX_OPERATOR_VERSIONED_KERNEL_EX(
-    Cast,
-    kOnnxDomain,
-    19, 20,
-    kWebGpuExecutionProvider,
-    (*KernelDefBuilder::Create())
-        .TypeConstraint("T1", CastOpTypeConstraints())
-        .TypeConstraint("T2", CastOpTypeConstraints()),
-    Cast);
-ONNX_OPERATOR_VERSIONED_KERNEL_EX(
-    Cast,
-    kOnnxDomain,
-    21, 22,
-    kWebGpuExecutionProvider,
-    (*KernelDefBuilder::Create())
-        .TypeConstraint("T1", CastOpTypeConstraints())
-        .TypeConstraint("T2", CastOpTypeConstraints()),
-    Cast);
-ONNX_OPERATOR_KERNEL_EX(
-    Cast,
-    kOnnxDomain,
-    23,
-    kWebGpuExecutionProvider,
-    (*KernelDefBuilder::Create())
-        .TypeConstraint("T1", CastOpTypeConstraints())
-        .TypeConstraint("T2", CastOpTypeConstraints()),
-    Cast);
-
 Status Cast::ComputeInternal(ComputeContext& context) const {
   const auto* input_tensor = context.Input(0);
   auto* output_tensor = context.Output(0, input_tensor->Shape());
   int64_t size = input_tensor->Shape().Size();
   if (size == 0) {
     return Status::OK();
   }
+  bool is_from_int64 = input_tensor->DataType() == DataTypeImpl::GetType<int64_t>();
+  const int in_components = is_from_int64 ? 1 : 4;
+  const int out_components = to_ == ONNX_NAMESPACE::TensorProto_DataType_INT64 ? 1 : 4;
   uint32_t vec_size = onnxruntime::narrow<uint32_t>((size + 3) / 4);
+  uint32_t in_vec_size = onnxruntime::narrow<uint32_t>(in_components == 1 ? size : vec_size);
+  uint32_t out_vec_size = onnxruntime::narrow<uint32_t>(out_components == 1 ? size : vec_size);
 
-  CastProgram program{to_};
+  CastProgram program{to_, is_from_int64};
   program
-      .AddInput({input_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4})
-      .AddOutput({output_tensor, ProgramTensorMetadataDependency::None, {vec_size}, 4})
+      .AddInput({input_tensor, ProgramTensorMetadataDependency::Type, {in_vec_size}, in_components})
+      .AddOutput({output_tensor, ProgramTensorMetadataDependency::None, {out_vec_size}, out_components})
       .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
       .AddUniformVariables({
           {static_cast<uint32_t>(vec_size)},
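The sizing logic added to `ComputeInternal` keeps the 4-wide vectorization for ordinary types but switches to per-scalar indexing whenever int64 appears on either side, while the dispatch size stays based on `vec_size`. A small self-contained sketch of that arithmetic follows; `CastSizes` and `ComputeCastSizes` are illustrative names, not part of the source:

```cpp
#include <cstdint>
#include <cstdio>

// Mirrors the sizing arithmetic above: non-int64 tensors are processed as vec4
// (4 components per shader element), while int64 tensors are addressed per
// scalar element so each 64-bit value can be handled component by component.
struct CastSizes {
  int in_components;
  int out_components;
  uint32_t vec_size;      // number of vectorized shader elements, ceil(size / 4)
  uint32_t in_vec_size;   // logical length of the shader input variable
  uint32_t out_vec_size;  // logical length of the shader output variable
};

CastSizes ComputeCastSizes(int64_t size, bool from_int64, bool to_int64) {
  CastSizes s;
  s.in_components = from_int64 ? 1 : 4;
  s.out_components = to_int64 ? 1 : 4;
  s.vec_size = static_cast<uint32_t>((size + 3) / 4);
  s.in_vec_size = static_cast<uint32_t>(s.in_components == 1 ? size : s.vec_size);
  s.out_vec_size = static_cast<uint32_t>(s.out_components == 1 ? size : s.vec_size);
  return s;
}

int main() {
  // Casting 10 int64 elements to float: 3 vectorized shader elements,
  // input addressed as 10 scalars, output addressed as 3 vec4s.
  CastSizes s = ComputeCastSizes(10, /*from_int64=*/true, /*to_int64=*/false);
  std::printf("vec_size=%u in_vec_size=%u out_vec_size=%u\n",
              s.vec_size, s.in_vec_size, s.out_vec_size);
  return 0;
}
```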
@@ -121,15 +80,78 @@ Status CastProgram::GenerateShaderCode(ShaderHelper& sh) const {
     case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
       expression = "vec4<bool>(a)";
       break;
+    case ONNX_NAMESPACE::TensorProto_DataType_INT64:
+      expression = "int32(a)";
+      break;
     default:
       ORT_NOT_IMPLEMENTED("Cast to type ", to_, " is not supported.");
   }
-  sh.MainFunctionBody() << sh.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size")
-                        << "let a = " << input.GetByOffset("global_idx") << ";\n"
-                        << output.SetByOffset("global_idx", expression);
+
+  sh.MainFunctionBody() << sh.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size");
+  if (is_from_int64_) {
+    sh.MainFunctionBody() << "let a0 = " << input.GetByOffset("global_idx * 4") << ";\n"
+                          << "let a1 = " << input.GetByOffset("global_idx * 4 + 1") << ";\n"
+                          << "let a2 = " << input.GetByOffset("global_idx * 4 + 2") << ";\n"
+                          << "let a3 = " << input.GetByOffset("global_idx * 4 + 3") << ";\n"
+                          << "let a = vec4<i32>(a0, a1, a2, a3);\n";
+  } else {
+    sh.MainFunctionBody() << "let a = " << input.GetByOffset("global_idx") << ";\n";
+  }
+  if (to_ == ONNX_NAMESPACE::TensorProto_DataType_INT64) {
+    sh.MainFunctionBody() << output.SetByOffset("global_idx * 4", "a.x") << "\n"
+                          << output.SetByOffset("global_idx * 4 + 1", "a.y") << "\n"
+                          << output.SetByOffset("global_idx * 4 + 2", "a.z") << "\n"
+                          << output.SetByOffset("global_idx * 4 + 3", "a.w") << "\n";
+  } else {
+    sh.MainFunctionBody() << output.SetByOffset("global_idx", expression);
+  }
 
   return Status::OK();
 }
 
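The branching above composes into four possible main-body shapes, one for each combination of the from-int64 and to-int64 flags. The following simplified stand-in builds the body text with plain array indexing instead of `GetByOffset`/`SetByOffset`; it is illustrative only (`BuildCastBody` is an invented name) and does not reproduce ShaderHelper's exact output:

```cpp
#include <iostream>
#include <sstream>
#include <string>

// Simplified stand-in for the ShaderHelper-based generator: shows how the
// from-int64 / to-int64 branches combine into the generated main body.
std::string BuildCastBody(bool from_int64, bool to_int64, const std::string& expression) {
  std::ostringstream body;
  body << "if (global_idx >= uniforms.vec_size) { return; }\n";
  if (from_int64) {
    // int64 input is read one scalar at a time and packed into a vec4<i32>.
    body << "let a0 = input[global_idx * 4];\n"
            "let a1 = input[global_idx * 4 + 1];\n"
            "let a2 = input[global_idx * 4 + 2];\n"
            "let a3 = input[global_idx * 4 + 3];\n"
            "let a = vec4<i32>(a0, a1, a2, a3);\n";
  } else {
    body << "let a = input[global_idx];\n";
  }
  if (to_int64) {
    // int64 output is written back one component at a time.
    body << "output[global_idx * 4] = a.x;\n"
            "output[global_idx * 4 + 1] = a.y;\n"
            "output[global_idx * 4 + 2] = a.z;\n"
            "output[global_idx * 4 + 3] = a.w;\n";
  } else {
    body << "output[global_idx] = " << expression << ";\n";
  }
  return body.str();
}

int main() {
  std::cout << BuildCastBody(/*from_int64=*/true, /*to_int64=*/false, "vec4<f32>(a)");
  return 0;
}
```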
+template <int StartVersion, int EndVersion>
+KernelCreateInfo CreateCastKernelInfo(bool enable_graph_capture) {
+  const auto& type_constraints = CastOpTypeConstraints(enable_graph_capture);
+
+  KernelCreateFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status {
+    out = std::make_unique<Cast>(info);
+    return Status::OK();
+  };
+
+  if constexpr (StartVersion == EndVersion) {
+    // Non-versioned kernel
+    return {
+        KernelDefBuilder()
+            .SetName("Cast")
+            .SetDomain(kOnnxDomain)
+            .SinceVersion(StartVersion)
+            .Provider(kWebGpuExecutionProvider)
+            .TypeConstraint("T1", type_constraints)
+            .TypeConstraint("T2", type_constraints)
+            .Build(),
+        kernel_create_fn};
+  } else {
+    // Versioned kernel
+    return {
+        KernelDefBuilder()
+            .SetName("Cast")
+            .SetDomain(kOnnxDomain)
+            .SinceVersion(StartVersion, EndVersion)
+            .Provider(kWebGpuExecutionProvider)
+            .TypeConstraint("T1", type_constraints)
+            .TypeConstraint("T2", type_constraints)
+            .Build(),
+        kernel_create_fn};
+  }
+}
+
+// Explicit template instantiations
+template KernelCreateInfo CreateCastKernelInfo<6, 8>(bool);
+template KernelCreateInfo CreateCastKernelInfo<9, 12>(bool);
+template KernelCreateInfo CreateCastKernelInfo<13, 18>(bool);
+template KernelCreateInfo CreateCastKernelInfo<19, 20>(bool);
+template KernelCreateInfo CreateCastKernelInfo<21, 22>(bool);
+template KernelCreateInfo CreateCastKernelInfo<23>(bool);
+
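The explicit instantiations let a registration site choose the graph-capture flag at runtime instead of baking the type constraints into kernel macros. A hypothetical sketch of such a call site is shown below; `RegisterCastKernels` is an invented name, the actual WebGPU EP registry code differs, and it assumes the onnxruntime headers declaring `KernelRegistry`, `KernelCreateInfo`, and a `Register(KernelCreateInfo&&)` overload are in scope:

```cpp
// Sketch only: not part of this diff. Registers every Cast opset range,
// forwarding the execution provider's graph-capture setting.
Status RegisterCastKernels(KernelRegistry& kernel_registry, bool enable_graph_capture) {
  KernelCreateInfo infos[] = {
      CreateCastKernelInfo<6, 8>(enable_graph_capture),
      CreateCastKernelInfo<9, 12>(enable_graph_capture),
      CreateCastKernelInfo<13, 18>(enable_graph_capture),
      CreateCastKernelInfo<19, 20>(enable_graph_capture),
      CreateCastKernelInfo<21, 22>(enable_graph_capture),
      CreateCastKernelInfo<23>(enable_graph_capture),  // opset 23 and up
  };
  for (auto& info : infos) {
    ORT_RETURN_IF_ERROR(kernel_registry.Register(std::move(info)));
  }
  return Status::OK();
}
```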
 }  // namespace webgpu
 }  // namespace onnxruntime