
Commit 4361d39

[ET-VK] Implementation of to_dim_order_copy (#15677)
Title says it all! Previously, to_dim_order_copy was handled by removing the op. However, that is not possible when the op changes the dtype of the original tensor, so those instances of the op were skipped by the partitioner. This diff adds an implementation that performs the dtype conversion, allowing to_dim_order_copy to be lowered. Differential Revision: [D86340341](https://our.internmc.facebook.com/intern/diff/D86340341/)
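For context, a minimal sketch of the export flow that produces such a node, assuming the standard torch.export / to_edge pipeline with the default edge compile config (which uses dim-order ops); the module and shapes are illustrative:

import torch
from executorch.exir import to_edge

class Cast(torch.nn.Module):
    def forward(self, x):
        # A uint8 -> float32 cast: in the edge dialect this becomes
        # dim_order_ops._to_dim_order_copy with differing input/output dtypes,
        # the exact case the Vulkan partitioner previously had to skip.
        return x.to(torch.float32)

ep = torch.export.export(Cast(), (torch.zeros(2, 3, dtype=torch.uint8),))
edge = to_edge(ep)
print(edge.exported_program().graph_module.graph)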

File tree

9 files changed: +214 / −61 lines


backends/vulkan/_passes/remove_redundant_ops.py

Lines changed: 19 additions & 17 deletions
@@ -31,35 +31,37 @@ class RemoveRedundantOpsTransform(ExportPass):
         exir_ops.edge.aten.lift_fresh_copy.default,
         exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
         exir_ops.edge.dim_order_ops._clone_dim_order.default,
+        exir_ops.edge.aten.expand_copy.default,
     }
 
     def __init__(self) -> None:
         super(RemoveRedundantOpsTransform, self).__init__()
 
     def _should_remove(self, node: torch.fx.Node) -> bool:
-        if node.target in self.redundant_ops:
-            return True
-
-        # Only remove to_copy if dtype does not change. Otherwise, memory format changes
-        # will be handled internally by the backend.
-        if (
-            node.target == exir_ops.edge.aten._to_copy.default
-            or node.target == torch.ops.aten._to_copy.default
-        ):
-            src_dtype = node.meta["val"].dtype
-            # pyre-ignore
-            dst_dtype = node.args[0].meta["val"].dtype
-            return src_dtype == dst_dtype
-
-        return False
+        if node.target not in self.redundant_ops:
+            return False
+
+        orig_node = node.args[0]
+        assert isinstance(orig_node, torch.fx.Node)
+
+        src_dtype = orig_node.meta["val"].dtype
+        dst_dtype = node.meta["val"].dtype
+
+        # Do not remove if the op is converting the dtype.
+        if src_dtype != dst_dtype:
+            return False
+
+        src_shape = orig_node.meta["val"].shape
+        dst_shape = node.meta["val"].shape
+
+        return src_shape == dst_shape
 
     def _remove(self, graph_module: torch.fx.GraphModule) -> None:
         for node in graph_module.graph.nodes:
             if not self._should_remove(node):
                 continue
 
-            with graph_module.graph.inserting_after(node):
-                node.replace_all_uses_with(node.args[0])
+            node.replace_all_uses_with(node.args[0])
 
         graph_module.graph.eliminate_dead_code()
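In effect, the rewritten check only removes ops that are true no-ops. A hedged Python paraphrase of the new criteria (not the pass's actual code; src and dst stand for the input and output meta["val"] FakeTensors):

def is_removable(node_in_redundant_ops: bool, src, dst) -> bool:
    # Keep dtype-converting copies so the backend can lower them, and keep
    # shape-changing ops such as an expand_copy that actually expands.
    if not node_in_redundant_ops:
        return False
    return src.dtype == dst.dtype and src.shape == dst.shape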

backends/vulkan/op_registry.py

Lines changed: 2 additions & 24 deletions
@@ -7,17 +7,12 @@
 # pyre-unsafe
 
 import operator
-
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import executorch.backends.vulkan.custom_ops_lib  # noqa
-
 import executorch.backends.vulkan.utils as utils
-
 import torch
-
 from executorch.exir.dialects._ops import ops as exir_ops
-
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
 from torch._subclasses.fake_tensor import FakeTensor
 
@@ -129,6 +124,7 @@ def update_features_impl(op: OpKey):
         # Symbolic integer ops
         torch.ops.aten.sym_size.int,
         operator.add,
+        operator.sub,
         operator.lt,
         operator.gt,
         operator.ge,
@@ -297,27 +293,9 @@ def check_to_copy_node(node: torch.fx.Node) -> bool:
 
 @update_features(exir_ops.edge.dim_order_ops._to_dim_order_copy.default)
 def register_to_copy_dim_order_op():
-    # Currently there is no "real" implementation for to_dim_order_copy, but it can be
-    # removed as long as the operator is not changing the dtype, i.e. the operator call
-    # is modifying the dim order only. Therefore, check that the input and output dtypes
-    # are the same, if so the operator is safe to remove.
-    def check_dim_order_copy_node(node: torch.fx.Node) -> bool:
-        in_arg = node.args[0]
-        if not isinstance(in_arg, torch.fx.Node):
-            return False
-
-        in_tensor = in_arg.meta.get("val", None)
-        out_tensor = node.meta.get("val", None)
-
-        if in_tensor.dtype != out_tensor.dtype:
-            return False
-
-        return True
-
     return OpFeatures(
-        inputs_storage=utils.ANY_STORAGE,
+        inputs_storage=utils.ANY_BUFFER,
         supports_resize=True,
-        are_node_inputs_supported_fn=check_dim_order_copy_node,
     )

backends/vulkan/runtime/graph/ops/glsl/view_buffer.glsl

Lines changed: 13 additions & 8 deletions
@@ -18,6 +18,8 @@ ${layout_declare_ubo(B, "BufferMetadata", "inp")}
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
+${layout_declare_spec_const(C, "int", "all_contiguous", "0")}
+
 /*
  * The insight behind the view operation is that the contiguous index of each
  * tensor element in the input and output tensors are the same.
@@ -28,17 +30,20 @@ void main() {
     return;
   }
 
-  TensorIndex outp_tidx;
-  linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx);
+  uint inp_bufi = outp_bufi;
+  if (all_contiguous == 0) {
+    TensorIndex outp_tidx;
+    linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx);
 
-  // To map the output to the input, find the input element that has the same
-  // contiguous index as the output element.
-  const uint contig_idx = tensor_idx_to_contiguous_idx(outp, outp_tidx);
+    // To map the output to the input, find the input element that has the same
+    // contiguous index as the output element.
+    const uint contig_idx = tensor_idx_to_contiguous_idx(outp, outp_tidx);
 
-  TensorIndex inp_tidx;
-  contiguous_idx_to_tensor_idx(inp, contig_idx, inp_tidx);
+    TensorIndex inp_tidx;
+    contiguous_idx_to_tensor_idx(inp, contig_idx, inp_tidx);
 
-  const uint inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx);
+    inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx);
+  }
 
   t_outp[outp_bufi] = t_inp[inp_bufi];
 }
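The new all_contiguous specialization constant exploits the fact that, for a contiguous buffer, an element's linear buffer index already equals its contiguous index, so the tensor-index round trip can be skipped. A NumPy sketch of the general mapping (helper names are illustrative, not the runtime's):

import numpy as np

def view_copy(inp: np.ndarray, out_sizes) -> np.ndarray:
    # The contiguous index of every element is preserved by a view: flatten
    # the input in contiguous (row-major) order, then re-stride to out_sizes.
    flat = np.ascontiguousarray(inp).reshape(-1)  # contiguous indices 0..N-1
    return flat.reshape(out_sizes).copy()

# Fast path: when input and output are both contiguous buffers,
# inp_bufi == outp_bufi and the shader reduces to a straight element-wise copy.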
backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.glsl

Lines changed: 54 additions & 0 deletions

@@ -0,0 +1,54 @@
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#define IN_T ${buffer_scalar_type(IN_DTYPE)}
+#define OUT_T ${buffer_scalar_type(OUT_DTYPE)}
+
+${define_required_extensions(IN_DTYPE)}
+${define_required_extensions(OUT_DTYPE)}
+
+layout(std430) buffer;
+
+#include "indexing.glslh"
+
+${layout_declare_buffer(B, "w", "t_outp", OUT_DTYPE)}
+${layout_declare_buffer(B, "r", "t_inp", IN_DTYPE)}
+
+${layout_declare_ubo(B, "BufferMetadata", "outp")}
+${layout_declare_ubo(B, "BufferMetadata", "inp")}
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+${layout_declare_spec_const(C, "int", "all_contiguous", "0")}
+
+/*
+ * The insight behind the view_convert operation is that the contiguous index
+ * of each tensor element in the input and output tensors are the same, but
+ * the data types may be different and need conversion.
+ */
+void main() {
+  const uint outp_bufi = gl_GlobalInvocationID.x;
+  if (outp_bufi >= numel(outp)) {
+    return;
+  }
+
+  uint inp_bufi = outp_bufi;
+
+  if (all_contiguous == 0) {
+    TensorIndex outp_tidx;
+    linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx);
+
+    // To map the output to the input, find the input element that has the same
+    // contiguous index as the output element.
+    const uint contig_idx = tensor_idx_to_contiguous_idx(outp, outp_tidx);
+
+    TensorIndex inp_tidx;
+    contiguous_idx_to_tensor_idx(inp, contig_idx, inp_tidx);
+
+    inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx);
+  }
+
+  // Convert data type from input to output
+  t_outp[outp_bufi] = OUT_T(t_inp[inp_bufi]);
+}
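The converting variant performs the same index mapping but casts each element as it is copied; note that OUT_T(...) in GLSL is a value conversion (like NumPy's astype), not a bitwise reinterpretation. A rough NumPy analogue (names illustrative):

import numpy as np

def view_convert(inp: np.ndarray, out_sizes, out_dtype) -> np.ndarray:
    # Mirrors t_outp[outp_bufi] = OUT_T(t_inp[inp_bufi]) over all elements.
    flat = np.ascontiguousarray(inp).reshape(-1)
    return flat.astype(out_dtype).reshape(out_sizes)

out = view_convert(np.arange(6, dtype=np.uint8).reshape(2, 3), (3, 2), np.float32)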
backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.yaml

Lines changed: 22 additions & 0 deletions

@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+view_convert_buffer:
+  parameter_names_with_default_values:
+    IN_DTYPE: float
+    OUT_DTYPE: float
+    STORAGE: buffer
+  generate_variant_forall:
+    combination:
+      parameter_names: [IN_DTYPE, OUT_DTYPE]
+      combos:
+        - parameter_values: [int32, float]
+        - parameter_values: [int32, half]
+        - parameter_values: [uint8, float]
+        - parameter_values: [uint8, half]
+        - parameter_values: [uint8, int32]
+  shader_variants:
+    - NAME: view_convert_buffer
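Each combo above yields one compiled shader variant; the dispatch code builds the variant name by appending a dtype suffix for the input and then the output, so the uint8 to half combo presumably resolves to a name like view_convert_buffer_uint8_half (the exact suffix spelling produced by add_dtype_suffix is an assumption here):

# Hypothetical helper mirroring the two add_dtype_suffix calls in View.cpp.
def pick_variant(in_dtype: str, out_dtype: str) -> str:
    return f"view_convert_buffer_{in_dtype}_{out_dtype}"

assert pick_variant("uint8", "half") == "view_convert_buffer_uint8_half"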

backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp

Lines changed: 12 additions & 0 deletions
@@ -67,6 +67,18 @@ void resize_unsqueeze_node(
 
   std::vector<int64_t> out_sizes = graph->sizes_of(in);
 
+  std::vector<int64_t> unsqueezed_dims;
+
+  if (graph->val_is_int_list(dims_ref)) {
+    const IntListPtr dims = graph->get_int_list(dims_ref);
+    for (int64_t d : *dims) {
+      unsqueezed_dims.push_back(d);
+    }
+  } else {
+    const int64_t dim = graph->extract_scalar<int64_t>(dims_ref);
+    unsqueezed_dims.push_back(dim);
+  }
+
   // Insert singleton dimensions at the specified positions
   for (auto dim : dims_vec) {
     int64_t d = dim;
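The new block normalizes the dims argument, which may arrive either as a single int or as an int list. A Python paraphrase of the same normalization:

def normalize_dims(dims) -> list:
    # Mirrors the val_is_int_list / extract_scalar branch in the C++ above.
    return list(dims) if isinstance(dims, (list, tuple)) else [int(dims)]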

backends/vulkan/runtime/graph/ops/impl/View.cpp

Lines changed: 79 additions & 0 deletions
@@ -60,6 +60,16 @@ void resize_view_node(
   }
 }
 
+void resize_to_dim_order_copy_node(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& extra_args) {
+  const ValueRef out = args.at(0).refs.at(0);
+  const ValueRef in = args.at(1).refs.at(0);
+  const std::vector<int64_t> in_sizes = graph->sizes_of(in);
+  graph->virtual_resize(out, in_sizes);
+}
+
 void add_view_node(
     ComputeGraph& graph,
     ValueRef in,
@@ -98,6 +108,11 @@ void add_view_copy_buffer_node(
   std::string kernel_name = "view_buffer";
   add_dtype_suffix(kernel_name, graph.dtype_of(out));
 
+  bool all_contiguous = graph.is_contiguous_buffer_tensor(in) &&
+      graph.is_contiguous_buffer_tensor(out);
+
+  int32_t all_contiguous_int = all_contiguous ? 1 : 0;
+
   graph.execute_nodes().emplace_back(new DynamicDispatchNode(
       graph,
       VK_KERNEL_FROM_STR(kernel_name),
@@ -110,7 +125,41 @@ void add_view_copy_buffer_node(
       // Push Constants
       {},
       // Specialization Constants
+      {all_contiguous_int},
+      // Resize Args
+      resize_args,
+      // Resizing Logic
+      resize_fn));
+}
+
+void add_view_copy_convert_buffer_node(
+    ComputeGraph& graph,
+    ValueRef in,
+    ValueRef out,
+    const std::vector<ValueRef>& resize_args,
+    const ExecuteNode::ResizeFunction& resize_fn) {
+  std::string kernel_name = "view_convert_buffer";
+  add_dtype_suffix(kernel_name, graph.dtype_of(in));
+  add_dtype_suffix(kernel_name, graph.dtype_of(out));
+
+  bool all_contiguous = graph.is_contiguous_buffer_tensor(in) &&
+      graph.is_contiguous_buffer_tensor(out);
+
+  int32_t all_contiguous_int = all_contiguous ? 1 : 0;
+
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      default_pick_global_wg_size,
+      default_pick_local_wg_size,
+      // Inputs and Outputs
+      {{out, vkapi::kWrite}, {in, vkapi::kRead}},
+      // Parameter Buffers
+      {graph.buffer_meta_ubo(out), graph.buffer_meta_ubo(in)},
+      // Push Constants
       {},
+      // Specialization Constants
+      {all_contiguous_int},
       // Resize Args
       resize_args,
       // Resizing Logic
@@ -132,8 +181,38 @@ void view(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   return add_view_node(graph, in, sizes, out);
 }
 
+void to_dim_order_copy(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  int args_idx = 0;
+  const ValueRef in = args.at(args_idx++);
+  const ValueRef dtype = args.at(args_idx++);
+  (void)dtype;
+  const ValueRef layout = args.at(args_idx++);
+  (void)layout;
+  const ValueRef device = args.at(args_idx++);
+  (void)device;
+  const ValueRef pin_memory = args.at(args_idx++);
+  (void)pin_memory;
+  const ValueRef non_blocking = args.at(args_idx++);
+  (void)non_blocking;
+  const ValueRef dim_order = args.at(args_idx++);
+  (void)dim_order;
+
+  const ValueRef out = args.at(args_idx++);
+
+  VK_CHECK_COND(graph.is_buffer_storage(in) && graph.is_buffer_storage(out));
+
+  if (graph.dtype_of(in) == graph.dtype_of(out)) {
+    return add_view_copy_buffer_node(
+        graph, in, out, {}, resize_to_dim_order_copy_node);
+  }
+
+  return add_view_copy_convert_buffer_node(
+      graph, in, out, {}, resize_to_dim_order_copy_node);
+}
+
 REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.view_copy.default, view);
+  VK_REGISTER_OP(dim_order_ops._to_dim_order_copy.default, to_dim_order_copy);
 }
 
 } // namespace vkcompute
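So to_dim_order_copy resolves to one of two buffer shaders, and the VK_CHECK_COND above enforces that both tensors use buffer storage, matching the inputs_storage=utils.ANY_BUFFER registration in op_registry.py. A one-line paraphrase of the dispatch rule:

def pick_kernel(in_dtype, out_dtype) -> str:
    # Same dtype: plain copy (view_buffer); different dtype: converting copy.
    return "view_buffer" if in_dtype == out_dtype else "view_convert_buffer"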

backends/vulkan/runtime/graph/ops/impl/View.h

Lines changed: 13 additions & 0 deletions
@@ -24,6 +24,19 @@ void add_view_copy_buffer_node(
     const std::vector<ValueRef>& resize_args,
     const ExecuteNode::ResizeFunction& resize_fn);
 
+/*
+ * Dispatches the view_convert_buffer compute shader. This can be used to
+ * implement ops that preserve the "contiguous" indexes of elements between
+ * the input and output while converting between different data types, such
+ * as view_copy with dtype conversion.
+ */
+void add_view_copy_convert_buffer_node(
+    ComputeGraph& graph,
+    ValueRef in,
+    ValueRef out,
+    const std::vector<ValueRef>& resize_args,
+    const ExecuteNode::ResizeFunction& resize_fn);
+
 void add_view_node(
     ComputeGraph& graph,
     ValueRef in,
