Skip to content

Commit 1eed125

Browse files
yipjustinfacebook-github-bot
authored andcommitted
aten.view_copy (#3129)
Summary: Pull Request resolved: #3129 aten.view_copy, supporting all packing. Using SS-JIA's idea to do a direct lookup. ghstack-source-id: 223111187 Reviewed By: SS-JIA Differential Revision: D56281400 fbshipit-source-id: 355493fc18c015523672665e7c1c37a4c92debdd
1 parent 523c2cb commit 1eed125

File tree

6 files changed

+209
-4
lines changed

6 files changed

+209
-4
lines changed

backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,21 @@
88

99
#define divup4(x) ((x + 3) / 4)
1010

11-
#define to_buffer_i(idx, sizes) \
12-
idx.x + idx.y* sizes.x + idx.z* sizes.y* sizes.x + \
13-
idx.w* sizes.z* sizes.y* sizes.x;
11+
// Input: idx is a ivec4 user-level coordinate, sizes is the tensor shape
12+
// Output: buffer_idx in the continuous nchw-buffer.
13+
#define to_buffer_i(idx, sizes) \
14+
(idx.x + idx.y * sizes.x + idx.z * sizes.y * sizes.x + \
15+
idx.w * sizes.z * sizes.y * sizes.x)
16+
17+
// Inverse of to_buffer_i
18+
// Input: buffer_idx in the continuous nchw-buffer, sizes is the tensor shape
19+
// Output: ivec4 user-level coorindate
20+
#define from_buffer_i(buf_i, sizes) \
21+
ivec4( \
22+
buf_i % sizes.x, \
23+
(buf_i / (sizes.x)) % sizes.y, \
24+
(buf_i / (sizes.x * sizes.y)) % sizes.z, \
25+
(buf_i / (sizes.x * sizes.y * sizes.z)))
1426

1527
#define get_packed_dim_C_packed(vec) vec.z
1628
#define get_packed_dim_W_packed(vec) vec.x
@@ -20,6 +32,8 @@
2032
#define get_packed_stride_W_packed(vec) (1)
2133
#define get_packed_stride_H_packed(vec) (vec.x)
2234

35+
// Input: pos is a texture position, sizes is a pack-aligned size.
36+
// Output: a user-level (w, h, c, n) coordinate
2337
#define to_tensor_idx_C_packed(pos, sizes) \
2438
ivec4(pos.x, pos.y, (pos.z * 4) % sizes.z, (pos.z * 4) / sizes.z)
2539

@@ -29,6 +43,9 @@
2943
#define to_tensor_idx_H_packed(pos, sizes) \
3044
ivec4(pos.x, (pos.y * 4), pos.z % sizes.z, pos.z / sizes.z)
3145

46+
// Input: idx is a user-level (w, h, c, n) coordinate. size is a pack-aligned
47+
// size.
48+
// Output: texture location
3249
#define to_texture_pos_C_packed(idx, sizes) \
3350
ivec3(idx.x, idx.y, (idx.z + idx.w * sizes.z) / 4)
3451

@@ -38,6 +55,19 @@
3855
#define to_texture_pos_H_packed(idx, sizes) \
3956
ivec3(idx.x, idx.y / 4, (idx.z + idx.w * sizes.z))
4057

58+
// Input: idx is a user-level (w, h, c, n) coordinate. size is a pack-aligned
59+
// size with the index in the texel.
60+
// Output: ivec4, xyz is the texture position, w is the element index in the
61+
// texel.
62+
#define to_texture_pos_elem_C_packed(idx, sizes) \
63+
ivec4(idx.x, idx.y, (idx.z + idx.w * sizes.z) / 4, idx.z % 4)
64+
65+
#define to_texture_pos_elem_W_packed(idx, sizes) \
66+
ivec4(idx.x / 4, idx.y, (idx.z + idx.w * sizes.z), idx.x % 4)
67+
68+
#define to_texture_pos_elem_H_packed(idx, sizes) \
69+
ivec4(idx.x, idx.y / 4, (idx.z + idx.w * sizes.z), idx.y % 4)
70+
4171
// Given a buffer(1-D) index cur, compute a new index where the corresponding
4272
// tensor(N-D)'s adjacent dimensions are swapped. The parameters x,y and plane
4373
// describe sizes. As an example, let's say we want to swap dimensions 0,1 for a
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define VEC4_T ${texel_type(DTYPE)}
14+
15+
layout(std430) buffer;
16+
17+
#include "indexing_utils.h"
18+
19+
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
20+
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
21+
22+
#define VEC4_T ${texel_type(DTYPE)}
23+
24+
#define to_tensor_idx to_tensor_idx_${PACKING}
25+
#define to_texture_pos_elem to_texture_pos_elem_${PACKING}
26+
#define get_packed_stride get_packed_stride_${PACKING}
27+
28+
layout(set = 0, binding = 2) uniform PRECISION restrict OutGpuSizes {
29+
uvec4 out_gpu_sizes;
30+
};
31+
32+
layout(set = 0, binding = 3) uniform PRECISION restrict OutCpuSizes {
33+
uvec4 out_cpu_sizes;
34+
};
35+
36+
layout(set = 0, binding = 4) uniform PRECISION restrict InGpuSizes {
37+
uvec4 in_gpu_sizes;
38+
};
39+
40+
layout(set = 0, binding = 5) uniform PRECISION restrict InCpuSizes {
41+
uvec4 in_cpu_sizes;
42+
};
43+
44+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
45+
46+
47+
void main() {
48+
const ivec3 out_pos = ivec3(gl_GlobalInvocationID);
49+
const ivec4 out_tensor_idx = to_tensor_idx(out_pos, out_gpu_sizes);
50+
51+
if (all(greaterThanEqual(out_tensor_idx, out_gpu_sizes))) {
52+
return;
53+
}
54+
55+
// Assume there is a virtual continous buffer in nchw format. From the output
56+
// pos, we first calculate the index in the virual buffer, and then calculate
57+
// the input position from the indx.
58+
59+
const uint base_index = to_buffer_i(out_tensor_idx, out_cpu_sizes);
60+
const uvec4 buf_indices =
61+
base_index + ivec4(0, 1, 2, 3) * get_packed_stride(out_cpu_sizes);
62+
63+
VEC4_T value;
64+
// Need to look up the 4 values in the output texel separately.
65+
for (int i=0; i<4; i++) {
66+
ivec4 user_coor = from_buffer_i(buf_indices[i], in_cpu_sizes);
67+
68+
ivec4 in_pos_elem = to_texture_pos_elem(user_coor, in_gpu_sizes);
69+
70+
VEC4_T intex = VEC4_T(texelFetch(image_in, in_pos_elem.xyz, 0));
71+
72+
value[i] = intex[in_pos_elem.w];
73+
}
74+
75+
imageStore(image_out, out_pos, value);
76+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
view:
2+
parameter_names_with_default_values:
3+
DTYPE: float
4+
NDIM: 3
5+
generate_variant_forall:
6+
DTYPE:
7+
- VALUE: half
8+
- VALUE: float
9+
PACKING:
10+
- VALUE: C_packed
11+
- VALUE: W_packed
12+
- VALUE: H_packed
13+
shader_variants:
14+
- NAME: view
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
12+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
13+
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
14+
15+
namespace vkcompute {
16+
17+
void add_view_node(ComputeGraph& graph, ValueRef in, ValueRef out) {
18+
vTensorPtr t_in = graph.get_tensor(in);
19+
vTensorPtr t_out = graph.get_tensor(out);
20+
21+
std::string kernel_name = "view";
22+
kernel_name.reserve(kShaderNameReserve);
23+
add_dtype_suffix(kernel_name, *t_out);
24+
add_memory_layout_suffix(kernel_name, *t_out);
25+
26+
api::utils::uvec3 global_size = t_out->extents();
27+
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
28+
29+
graph.execute_nodes().emplace_back(new ExecuteNode(
30+
graph,
31+
VK_KERNEL_FROM_STR(kernel_name),
32+
global_size,
33+
local_size,
34+
{{out, api::MemoryAccessType::WRITE}, {in, api::MemoryAccessType::READ}},
35+
{t_out->gpu_sizes_ubo(),
36+
t_out->cpu_sizes_ubo(),
37+
t_in->gpu_sizes_ubo(),
38+
t_in->cpu_sizes_ubo()}));
39+
}
40+
41+
void view(ComputeGraph& graph, const std::vector<ValueRef>& args) {
42+
// Note: The second argument size_ref is not used here. Since the output
43+
// tensor's size have been determined during compilation.
44+
return add_view_node(graph, args[0], args[2]);
45+
}
46+
47+
REGISTER_OPERATORS {
48+
VK_REGISTER_OP(aten.view_copy.default, view);
49+
}
50+
51+
} // namespace vkcompute

backends/vulkan/test/op_tests/cases.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,33 @@ def get_permute_inputs():
194194
return test_suite
195195

196196

197+
def get_view_inputs():
198+
test_suite = VkTestSuite(
199+
[
200+
((3, 4, 5), [1, 1, -1]),
201+
((3, 4, 5), [1, -1, 1]),
202+
((3, 4, 5), [-1, 1, 1]),
203+
((8, 7, 2, 3), [4, 3, 7, 4]),
204+
((8, 7, 2, 3), [7, -1, 2, 1]),
205+
((8, 7, 2, 3), [1, 1, 1, -1]),
206+
((8, 7, 2, 3), [-1]),
207+
((2, 3, 3, 7), [2, -1, 1, 1]),
208+
((3, 5, 2, 7), [7, -1, 2, 1]),
209+
((2, 2, 8, 6), [2, 6, -1, 1]),
210+
((2, 2, 8, 6), [6, -1, 1]),
211+
((S1, S2, S1, S2), [S2, -1, 1, S1]),
212+
((S1, S2, S1, S2), [S1, 1, -1, S2]),
213+
((S1, S2, S1, S2), [-1, 1, S1, S2]),
214+
]
215+
)
216+
test_suite.layouts = [
217+
"api::kWidthPacked",
218+
"api::kHeightPacked",
219+
"api::kChannelsPacked",
220+
]
221+
return test_suite
222+
223+
197224
test_suites = {
198225
"aten.add.Tensor": get_binary_elementwise_inputs(),
199226
"aten.sub.Tensor": get_binary_elementwise_inputs(),
@@ -208,4 +235,5 @@ def get_permute_inputs():
208235
"aten.select_copy.int": get_select_int_inputs(),
209236
"aten.permute.default": get_permute_inputs(),
210237
"aten.permute_copy.default": get_permute_inputs(),
238+
"aten.view_copy.default": get_view_inputs(),
211239
}

backends/vulkan/test/op_tests/utils/codegen_base.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,16 @@ def gen_case_name(self, inputs: List[Any], prepack: bool = False) -> str:
105105
for size in arg_sizes_or_val:
106106
name_str += str(size) + "x"
107107
name_str = name_str[:-1]
108+
# minus sign is a invalid char for test case. change to "n".
109+
name_str = name_str.replace("-", "n")
110+
108111
elif isinstance(arg_sizes_or_val, list):
109112
for size in arg_sizes_or_val:
110113
name_str += str(size) + "c"
111114
name_str = name_str[:-1]
115+
# minus sign is a invalid char for test case. change to "n".
116+
name_str = name_str.replace("-", "n")
117+
112118
else:
113119
name_str += str(arg_sizes_or_val).replace(".", "p")
114120
return name_str
@@ -234,7 +240,7 @@ def generate_suite_cpp(self) -> str:
234240
235241
// from_blob doesn't take ownership of data. Hence must create a copy as
236242
// "values" will go out of scope.
237-
return at::from_blob(values.data(), sizes, dtype).detach().clone();
243+
return at::from_blob(values.data(), sizes, at::kFloat).toType(dtype).detach().clone();
238244
}}
239245
240246
{test_suites_cpp}

0 commit comments

Comments
 (0)