diff --git a/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp b/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp index b95b7b3aa6d..e48042c4620 100644 --- a/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp +++ b/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp @@ -14,6 +14,8 @@ #include #include +#include "test_utils.h" + #include // @@ -201,26 +203,6 @@ void test_reference_linear_qcs4w( ASSERT_TRUE(at::allclose(out, out_ref)); } -vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { - using namespace vkcompute; - switch (at_scalartype) { - case c10::kFloat: - return vkapi::kFloat; - case c10::kHalf: - return vkapi::kHalf; - case c10::kInt: - return vkapi::kInt; - case c10::kLong: - return vkapi::kInt; - case c10::kChar: - return vkapi::kChar; - case c10::kByte: - return vkapi::kByte; - default: - VK_THROW("Unsupported at::ScalarType!"); - } -} - void test_vulkan_linear_qga4w_impl( const int B, const int M, diff --git a/backends/vulkan/test/op_tests/rotary_embedding_test.cpp b/backends/vulkan/test/op_tests/rotary_embedding_test.cpp index 534bb577e7a..eebbb89ab40 100644 --- a/backends/vulkan/test/op_tests/rotary_embedding_test.cpp +++ b/backends/vulkan/test/op_tests/rotary_embedding_test.cpp @@ -14,6 +14,8 @@ #include #include +#include "test_utils.h" + #include // @@ -55,26 +57,6 @@ std::pair rotary_embedding_impl( // Test functions // -vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { - using namespace vkcompute; - switch (at_scalartype) { - case c10::kFloat: - return vkapi::kFloat; - case c10::kHalf: - return vkapi::kHalf; - case c10::kInt: - return vkapi::kInt; - case c10::kLong: - return vkapi::kInt; - case c10::kChar: - return vkapi::kChar; - case c10::kByte: - return vkapi::kByte; - default: - VK_THROW("Unsupported at::ScalarType!"); - } -} - void test_reference( const int n_heads = 4, const int n_kv_heads = 2, diff --git a/backends/vulkan/test/op_tests/sdpa_test.cpp b/backends/vulkan/test/op_tests/sdpa_test.cpp index 772039eda6a..79b679674a5 100644 --- a/backends/vulkan/test/op_tests/sdpa_test.cpp +++ b/backends/vulkan/test/op_tests/sdpa_test.cpp @@ -18,6 +18,8 @@ #include #include +#include "test_utils.h" + #include #include @@ -261,24 +263,6 @@ void test_reference_sdpa( } } -vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { - using namespace vkcompute; - switch (at_scalartype) { - case c10::kFloat: - return vkapi::kFloat; - case c10::kHalf: - return vkapi::kHalf; - case c10::kInt: - return vkapi::kInt; - case c10::kLong: - return vkapi::kInt; - case c10::kChar: - return vkapi::kChar; - default: - VK_THROW("Unsupported at::ScalarType!"); - } -} - void test_vulkan_sdpa( const int start_input_pos, const int base_sequence_len, diff --git a/backends/vulkan/test/op_tests/targets.bzl b/backends/vulkan/test/op_tests/targets.bzl index 5c9afa40762..6fcf2d83538 100644 --- a/backends/vulkan/test/op_tests/targets.bzl +++ b/backends/vulkan/test/op_tests/targets.bzl @@ -142,6 +142,28 @@ def define_common_targets(is_fbcode = False): platforms = get_platforms(), ) + runtime.cxx_library( + name = "test_utils", + srcs = [ + "test_utils.cpp", + ], + headers = [ + "test_utils.h", + ], + exported_headers = [ + "test_utils.h", + ], + deps = [ + "//executorch/backends/vulkan:vulkan_graph_runtime", + "//executorch/runtime/core/exec_aten:lib", + runtime.external_dep_location("libtorch"), + ], + visibility = [ + "//executorch/backends/vulkan/test/op_tests/...", + "@EXECUTORCH_CLIENTS", + ], + ) + define_test_targets( "compute_graph_op_tests", src_file=":generated_op_correctness_tests_cpp[op_tests.cpp]" @@ -150,9 +172,20 @@ def define_common_targets(is_fbcode = False): define_test_targets( "sdpa_test", extra_deps = [ + ":test_utils", "//executorch/extension/llm/custom_ops:custom_ops_aot_lib", "//executorch/extension/tensor:tensor", ] ) - define_test_targets("linear_weight_int4_test") - define_test_targets("rotary_embedding_test") + define_test_targets( + "linear_weight_int4_test", + extra_deps = [ + ":test_utils", + ] + ) + define_test_targets( + "rotary_embedding_test", + extra_deps = [ + ":test_utils", + ] + ) diff --git a/backends/vulkan/test/op_tests/test_utils.cpp b/backends/vulkan/test/op_tests/test_utils.cpp new file mode 100644 index 00000000000..196f079be2c --- /dev/null +++ b/backends/vulkan/test/op_tests/test_utils.cpp @@ -0,0 +1,114 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "test_utils.h" + +#include + +executorch::aten::ScalarType at_scalartype_to_et_scalartype( + at::ScalarType dtype) { + using ScalarType = executorch::aten::ScalarType; + switch (dtype) { + case at::kByte: + return ScalarType::Byte; + case at::kChar: + return ScalarType::Char; + case at::kShort: + return ScalarType::Short; + case at::kInt: + return ScalarType::Int; + case at::kLong: + return ScalarType::Long; + case at::kHalf: + return ScalarType::Half; + case at::kFloat: + return ScalarType::Float; + case at::kDouble: + return ScalarType::Double; + default: + throw std::runtime_error("Unsupported dtype"); + } +} + +std::string scalar_type_name(c10::ScalarType dtype) { + switch (dtype) { + case c10::kLong: + return "c10::kLong"; + case c10::kShort: + return "c10::kShort"; + case c10::kComplexHalf: + return "c10::kComplexHalf"; + case c10::kComplexFloat: + return "c10::kComplexFloat"; + case c10::kComplexDouble: + return "c10::kComplexDouble"; + case c10::kBool: + return "c10::kBool"; + case c10::kQInt8: + return "c10::kQInt8"; + case c10::kQUInt8: + return "c10::kQUInt8"; + case c10::kQInt32: + return "c10::kQInt32"; + case c10::kBFloat16: + return "c10::kBFloat16"; + case c10::kQUInt4x2: + return "c10::kQUInt4x2"; + case c10::kQUInt2x4: + return "c10::kQUInt2x4"; + case c10::kFloat: + return "c10::kFloat"; + case c10::kHalf: + return "c10::kHalf"; + case c10::kInt: + return "c10::kInt"; + case c10::kChar: + return "c10::kChar"; + case c10::kByte: + return "c10::kByte"; + case c10::kDouble: + return "c10::kDouble"; + case c10::kUInt16: + return "c10::kUInt16"; + case c10::kBits16: + return "c10::kBits16"; + default: + return "Unknown(" + std::to_string(static_cast(dtype)) + ")"; + } +} + +vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { + using namespace vkcompute; + switch (at_scalartype) { + case c10::kHalf: + return vkapi::kHalf; + case c10::kFloat: + return vkapi::kFloat; + case c10::kDouble: + return vkapi::kDouble; + case c10::kInt: + return vkapi::kInt; + case c10::kLong: + return vkapi::kLong; + case c10::kChar: + return vkapi::kChar; + case c10::kByte: + return vkapi::kByte; + case c10::kShort: + return vkapi::kShort; + case c10::kUInt16: + return vkapi::kUInt16; + default: + VK_THROW( + "Unsupported at::ScalarType: ", + scalar_type_name(at_scalartype), + " (", + static_cast(at_scalartype), + ")"); + } +} diff --git a/backends/vulkan/test/op_tests/test_utils.h b/backends/vulkan/test/op_tests/test_utils.h new file mode 100644 index 00000000000..369767007e0 --- /dev/null +++ b/backends/vulkan/test/op_tests/test_utils.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include +#include +#include +#include + +/** + * Convert at::ScalarType to executorch::ScalarType + */ +executorch::aten::ScalarType at_scalartype_to_et_scalartype( + at::ScalarType dtype); + +/** + * Get the string name of a c10::ScalarType for better error messages + */ +std::string scalar_type_name(c10::ScalarType dtype); + +/** + * Convert c10::ScalarType to vkcompute::vkapi::ScalarType + */ +vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype);