Skip to content

Commit e60f17a

Browse files
authored
[ET-VK][Ops] common test utils for converting aten types to vulkan types
Differential Revision: D76464550 Pull Request resolved: #11575
1 parent fa08661 commit e60f17a

File tree

6 files changed

+187
-60
lines changed

6 files changed

+187
-60
lines changed

backends/vulkan/test/op_tests/linear_weight_int4_test.cpp

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
1515
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1616

17+
#include "test_utils.h"
18+
1719
#include <cassert>
1820

1921
//
@@ -201,26 +203,6 @@ void test_reference_linear_qcs4w(
201203
ASSERT_TRUE(at::allclose(out, out_ref));
202204
}
203205

204-
vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
205-
using namespace vkcompute;
206-
switch (at_scalartype) {
207-
case c10::kFloat:
208-
return vkapi::kFloat;
209-
case c10::kHalf:
210-
return vkapi::kHalf;
211-
case c10::kInt:
212-
return vkapi::kInt;
213-
case c10::kLong:
214-
return vkapi::kInt;
215-
case c10::kChar:
216-
return vkapi::kChar;
217-
case c10::kByte:
218-
return vkapi::kByte;
219-
default:
220-
VK_THROW("Unsupported at::ScalarType!");
221-
}
222-
}
223-
224206
void test_vulkan_linear_qga4w_impl(
225207
const int B,
226208
const int M,

backends/vulkan/test/op_tests/rotary_embedding_test.cpp

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
1515
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1616

17+
#include "test_utils.h"
18+
1719
#include <cassert>
1820

1921
//
@@ -55,26 +57,6 @@ std::pair<at::Tensor, at::Tensor> rotary_embedding_impl(
5557
// Test functions
5658
//
5759

58-
vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
59-
using namespace vkcompute;
60-
switch (at_scalartype) {
61-
case c10::kFloat:
62-
return vkapi::kFloat;
63-
case c10::kHalf:
64-
return vkapi::kHalf;
65-
case c10::kInt:
66-
return vkapi::kInt;
67-
case c10::kLong:
68-
return vkapi::kInt;
69-
case c10::kChar:
70-
return vkapi::kChar;
71-
case c10::kByte:
72-
return vkapi::kByte;
73-
default:
74-
VK_THROW("Unsupported at::ScalarType!");
75-
}
76-
}
77-
7860
void test_reference(
7961
const int n_heads = 4,
8062
const int n_kv_heads = 2,

backends/vulkan/test/op_tests/sdpa_test.cpp

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
1919
#include <executorch/extension/llm/custom_ops/op_sdpa.h>
2020

21+
#include "test_utils.h"
22+
2123
#include <cassert>
2224
#include <iostream>
2325

@@ -261,24 +263,6 @@ void test_reference_sdpa(
261263
}
262264
}
263265

264-
vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
265-
using namespace vkcompute;
266-
switch (at_scalartype) {
267-
case c10::kFloat:
268-
return vkapi::kFloat;
269-
case c10::kHalf:
270-
return vkapi::kHalf;
271-
case c10::kInt:
272-
return vkapi::kInt;
273-
case c10::kLong:
274-
return vkapi::kInt;
275-
case c10::kChar:
276-
return vkapi::kChar;
277-
default:
278-
VK_THROW("Unsupported at::ScalarType!");
279-
}
280-
}
281-
282266
void test_vulkan_sdpa(
283267
const int start_input_pos,
284268
const int base_sequence_len,

backends/vulkan/test/op_tests/targets.bzl

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,28 @@ def define_common_targets(is_fbcode = False):
142142
platforms = get_platforms(),
143143
)
144144

145+
runtime.cxx_library(
146+
name = "test_utils",
147+
srcs = [
148+
"test_utils.cpp",
149+
],
150+
headers = [
151+
"test_utils.h",
152+
],
153+
exported_headers = [
154+
"test_utils.h",
155+
],
156+
deps = [
157+
"//executorch/backends/vulkan:vulkan_graph_runtime",
158+
"//executorch/runtime/core/exec_aten:lib",
159+
runtime.external_dep_location("libtorch"),
160+
],
161+
visibility = [
162+
"//executorch/backends/vulkan/test/op_tests/...",
163+
"@EXECUTORCH_CLIENTS",
164+
],
165+
)
166+
145167
define_test_targets(
146168
"compute_graph_op_tests",
147169
src_file=":generated_op_correctness_tests_cpp[op_tests.cpp]"
@@ -150,9 +172,20 @@ def define_common_targets(is_fbcode = False):
150172
define_test_targets(
151173
"sdpa_test",
152174
extra_deps = [
175+
":test_utils",
153176
"//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
154177
"//executorch/extension/tensor:tensor",
155178
]
156179
)
157-
define_test_targets("linear_weight_int4_test")
158-
define_test_targets("rotary_embedding_test")
180+
define_test_targets(
181+
"linear_weight_int4_test",
182+
extra_deps = [
183+
":test_utils",
184+
]
185+
)
186+
define_test_targets(
187+
"rotary_embedding_test",
188+
extra_deps = [
189+
":test_utils",
190+
]
191+
)
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include "test_utils.h"
10+
11+
#include <stdexcept>
12+
13+
executorch::aten::ScalarType at_scalartype_to_et_scalartype(
14+
at::ScalarType dtype) {
15+
using ScalarType = executorch::aten::ScalarType;
16+
switch (dtype) {
17+
case at::kByte:
18+
return ScalarType::Byte;
19+
case at::kChar:
20+
return ScalarType::Char;
21+
case at::kShort:
22+
return ScalarType::Short;
23+
case at::kInt:
24+
return ScalarType::Int;
25+
case at::kLong:
26+
return ScalarType::Long;
27+
case at::kHalf:
28+
return ScalarType::Half;
29+
case at::kFloat:
30+
return ScalarType::Float;
31+
case at::kDouble:
32+
return ScalarType::Double;
33+
default:
34+
throw std::runtime_error("Unsupported dtype");
35+
}
36+
}
37+
38+
std::string scalar_type_name(c10::ScalarType dtype) {
39+
switch (dtype) {
40+
case c10::kLong:
41+
return "c10::kLong";
42+
case c10::kShort:
43+
return "c10::kShort";
44+
case c10::kComplexHalf:
45+
return "c10::kComplexHalf";
46+
case c10::kComplexFloat:
47+
return "c10::kComplexFloat";
48+
case c10::kComplexDouble:
49+
return "c10::kComplexDouble";
50+
case c10::kBool:
51+
return "c10::kBool";
52+
case c10::kQInt8:
53+
return "c10::kQInt8";
54+
case c10::kQUInt8:
55+
return "c10::kQUInt8";
56+
case c10::kQInt32:
57+
return "c10::kQInt32";
58+
case c10::kBFloat16:
59+
return "c10::kBFloat16";
60+
case c10::kQUInt4x2:
61+
return "c10::kQUInt4x2";
62+
case c10::kQUInt2x4:
63+
return "c10::kQUInt2x4";
64+
case c10::kFloat:
65+
return "c10::kFloat";
66+
case c10::kHalf:
67+
return "c10::kHalf";
68+
case c10::kInt:
69+
return "c10::kInt";
70+
case c10::kChar:
71+
return "c10::kChar";
72+
case c10::kByte:
73+
return "c10::kByte";
74+
case c10::kDouble:
75+
return "c10::kDouble";
76+
case c10::kUInt16:
77+
return "c10::kUInt16";
78+
case c10::kBits16:
79+
return "c10::kBits16";
80+
default:
81+
return "Unknown(" + std::to_string(static_cast<int>(dtype)) + ")";
82+
}
83+
}
84+
85+
vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
86+
using namespace vkcompute;
87+
switch (at_scalartype) {
88+
case c10::kHalf:
89+
return vkapi::kHalf;
90+
case c10::kFloat:
91+
return vkapi::kFloat;
92+
case c10::kDouble:
93+
return vkapi::kDouble;
94+
case c10::kInt:
95+
return vkapi::kInt;
96+
case c10::kLong:
97+
return vkapi::kLong;
98+
case c10::kChar:
99+
return vkapi::kChar;
100+
case c10::kByte:
101+
return vkapi::kByte;
102+
case c10::kShort:
103+
return vkapi::kShort;
104+
case c10::kUInt16:
105+
return vkapi::kUInt16;
106+
default:
107+
VK_THROW(
108+
"Unsupported at::ScalarType: ",
109+
scalar_type_name(at_scalartype),
110+
" (",
111+
static_cast<int>(at_scalartype),
112+
")");
113+
}
114+
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <string>
12+
13+
#include <ATen/ATen.h>
14+
#include <c10/core/ScalarType.h>
15+
#include <executorch/backends/vulkan/runtime/api/api.h>
16+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
17+
18+
/**
19+
* Convert at::ScalarType to executorch::ScalarType
20+
*/
21+
executorch::aten::ScalarType at_scalartype_to_et_scalartype(
22+
at::ScalarType dtype);
23+
24+
/**
25+
* Get the string name of a c10::ScalarType for better error messages
26+
*/
27+
std::string scalar_type_name(c10::ScalarType dtype);
28+
29+
/**
30+
* Convert c10::ScalarType to vkcompute::vkapi::ScalarType
31+
*/
32+
vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype);

0 commit comments

Comments
 (0)