Skip to content

Commit 9a74207

Browse files
authored
[ET-VK] Add support for aten::upsample_bilinear2d ATen op (#10363)
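Adds Vulkan backend support for `aten::upsample_bilinear2d`. The existing `upsample_nearest2d` shader is generalized into a single `upsample_2d` shader template with a `MODE` parameter (`nearest` or `bilinear`) and an `align_corners` specialization constant; the new op is added to the op registry, registered in `Upsample.cpp`, and given an op test suite. Differential Revision: [D73261394](https://our.internmc.facebook.com/intern/diff/D73261394/)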
1 parent c179f0d commit 9a74207

File tree

6 files changed

+183
-93
lines changed

6 files changed

+183
-93
lines changed

backends/vulkan/op_registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,7 @@ def register_view_op(features: OpFeatures):
540540
exir_ops.edge.aten.ones.default,
541541
exir_ops.edge.aten.ones_like.default,
542542
exir_ops.edge.aten.upsample_nearest2d.vec,
543+
exir_ops.edge.aten.upsample_bilinear2d.vec,
543544
exir_ops.edge.aten.zeros.default,
544545
exir_ops.edge.aten.zeros_like.default,
545546
exir_ops.edge.et_vk.grid_priors.default,
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
14+
15+
layout(std430) buffer;
16+
17+
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
18+
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
19+
${layout_declare_ubo(B, "ivec3", "out_limits")}
20+
${layout_declare_ubo(B, "ivec3", "in_limits")}
21+
${layout_declare_ubo(B, "vec2", "recip_scales")}
22+
23+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
24+
25+
layout(constant_id = 3) const int align_corners = 0;
26+
27+
// Shader entry point. Each invocation handles one output texel position:
// it maps the output (x, y) back into input coordinates using recip_scales,
// then writes either the nearest input texel (MODE == "nearest") or a blend
// of the four surrounding input texels (MODE == "bilinear") to t_out.
// The z coordinate is passed through unchanged.
void main() {
28+
const ivec3 pos = ivec3(gl_GlobalInvocationID);
29+
30+
// Invocations outside the output extents have nothing to write.
if (any(greaterThanEqual(pos, out_limits))) {
31+
return;
32+
}
33+
34+
// Largest valid (x, y) texel index in the input.
ivec2 max_in_xy = in_limits.xy - 1;
35+
vec2 scaled_xy;
36+
37+
// align_corners is a specialization constant (constant_id = 3, default 0).
// With align_corners the output index is scaled directly; otherwise the
// mapping is done on texel centers (+0.5 before scaling, -0.5 after).
if (align_corners == 1) {
38+
scaled_xy = pos.xy * recip_scales;
39+
} else {
40+
scaled_xy = (pos.xy + 0.5) * recip_scales - 0.5;
41+
}
42+
43+
// The $if/$elif lines are shader-template directives keyed on MODE; only
// one of the two branches below ends up in each generated variant.
$if MODE == "nearest":
44+
// Round to the closest input texel and keep it inside the input bounds.
const ivec2 ipos = clamp(ivec2(round(scaled_xy)), ivec2(0), max_in_xy);
45+
VEC4_T out_tex = texelFetch(t_in, ivec3(ipos, pos.z), 0);
46+
$elif MODE == "bilinear":
47+
vec2 upper_xy = ceil(scaled_xy);
48+
vec2 lower_xy = floor(scaled_xy);
49+
50+
// Clamp coordinates to valid input range
51+
upper_xy = clamp(upper_xy, ivec2(0), max_in_xy);
52+
lower_xy = clamp(lower_xy, ivec2(0), max_in_xy);
53+
54+
// Calculate interpolation weights
// The weights use the clamped lower_xy, so they can fall outside [0, 1]
// when scaled_xy lies outside [0, max_in_xy]. In that case lower_xy and
// upper_xy clamp to the same texel on that axis, so the mix below blends a
// texel with itself and the weight has no effect beyond rounding.
55+
vec2 interp_weights = (scaled_xy - lower_xy);
56+
57+
// Sample the four nearest texels
// Naming: sampleXY where X/Y are 0 for the lower and 1 for the upper
// coordinate on the x/y axis respectively.
58+
VEC4_T sample00 = texelFetch(t_in, ivec3(lower_xy.x, lower_xy.y, pos.z), 0);
59+
VEC4_T sample10 = texelFetch(t_in, ivec3(upper_xy.x, lower_xy.y, pos.z), 0);
60+
VEC4_T sample01 = texelFetch(t_in, ivec3(lower_xy.x, upper_xy.y, pos.z), 0);
61+
VEC4_T sample11 = texelFetch(t_in, ivec3(upper_xy.x, upper_xy.y, pos.z), 0);
62+
63+
// Perform bilinear interpolation
// Blend along x on the lower and upper rows first, then blend the two
// row results along y.
64+
VEC4_T out_tex = mix(
65+
mix(sample00, sample10, interp_weights.x),
66+
mix(sample01, sample11, interp_weights.x),
67+
interp_weights.y
68+
);
69+
70+
imageStore(t_out, pos, out_tex);
71+
}

backends/vulkan/runtime/graph/ops/glsl/upsample_nearest2d.yaml renamed to backends/vulkan/runtime/graph/ops/glsl/upsample_2d.yaml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,16 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
upsample_nearest2d:
7+
upsample_2d:
88
parameter_names_with_default_values:
9-
NDIM: 3
109
DTYPE: float
11-
PACKING: C_packed
1210
STORAGE: texture3d
11+
MODE: nearest
1312
generate_variant_forall:
1413
DTYPE:
1514
- VALUE: half
1615
- VALUE: float
1716
shader_variants:
1817
- NAME: upsample_nearest2d
18+
- NAME: upsample_bilinear2d
19+
MODE: bilinear

backends/vulkan/runtime/graph/ops/glsl/upsample_nearest2d.glsl

Lines changed: 0 additions & 39 deletions
This file was deleted.

backends/vulkan/runtime/graph/ops/impl/Upsample.cpp

Lines changed: 80 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
namespace vkcompute {
1818

19+
// Interpolation mode passed to add_upsample_nearest2d_node; it selects
// between the "upsample_nearest2d" and "upsample_bilinear2d" shader names.
enum class UpsampleMode : int { NEAREST, BILINEAR };
20+
1921
void resize_upsample_nearest2d_node(
2022
ComputeGraph* graph,
2123
const std::vector<ArgGroup>& args,
@@ -39,19 +41,12 @@ void resize_upsample_nearest2d_node(
3941
out->virtual_resize(out_sizes);
4042
}
4143

42-
// ExecuTorch-Vulkan framework to add node
43-
// Args:
44-
// in: will be converted from NCHW input tensor to 3D ARGB representation in
45-
// openGL (via ExecuTorch) output_sizes: optional 2D array of targetting
46-
// output size of H and W dimensions. >= input sizes;
47-
48-
// will be computed if only given the scale_factors.
49-
// scale_factors: optional 2D array of scale factors for H and W dimensions.
50-
// Will be computed if only given the output_sizes.
5144
// Appends a DispatchNode to the graph that runs the upsample shader chosen by
// `mode` ("upsample_nearest2d" or "upsample_bilinear2d", with a dtype suffix
// derived from `out`) reading `in` and writing `out`.
//
//   mode           which shader variant to dispatch.
//   in / out       input and output tensor refs.
//   output_sizes   int list; entry 0 is read as height, entry 1 as width.
//   align_corners  read as a bool only when the ref is valid, otherwise
//                  treated as false.
//   scale_factors  double list; entry 0 is read as the y factor, entry 1 as
//                  the x factor.
//
// The visible error message indicates that exactly one of output_sizes and
// scale_factors is expected to be provided. Despite the "nearest2d" in its
// name, this function serves both modes.
//
// NOTE(review): this is a diff rendering; some lines of the function are
// hidden behind the hunk boundaries and are not described here.
void add_upsample_nearest2d_node(
5245
ComputeGraph& graph,
46+
const UpsampleMode mode,
5347
const ValueRef in,
5448
const ValueRef output_sizes,
49+
const ValueRef align_corners,
5550
const ValueRef scale_factors,
5651
const ValueRef out) {
5752
if (graph.val_is_none(output_sizes) && graph.val_is_none(scale_factors)) {
@@ -63,36 +58,61 @@ void add_upsample_nearest2d_node(
6358
"Invalid input, must provide ONLY one of output_sizes or scale_factors");
6459
}
6560

66-
vTensorPtr t_in = graph.get_tensor(in);
67-
utils::uvec3 input_sizes = t_in->logical_limits();
61+
// Converted to an int because it is forwarded as a specialization constant
// that the shader compares against 1.
int align_corners_val = 0;
62+
if (is_valid(align_corners) && graph.get_bool(align_corners)) {
63+
align_corners_val = 1;
64+
}
65+
66+
utils::uvec3 in_limits = graph.logical_limits_of(in);
67+
utils::uvec3 out_limits = graph.logical_limits_of(out);
68+
69+
// Index 0 of the limits is used as width (x), index 1 as height (y).
uint32_t out_width = out_limits[0u];
70+
uint32_t out_height = out_limits[1u];
6871

69-
utils::ivec2 input_size = {
70-
utils::safe_downcast<int32_t>(input_sizes[0]),
71-
utils::safe_downcast<int32_t>(input_sizes[1])};
72-
utils::vec2 rev_scales = {
73-
utils::safe_downcast<float>(1.0), utils::safe_downcast<float>(1.0)};
72+
// NOTE(review): these initial values are input/output ratios; they are only
// kept if the output_sizes branch below is taken, and are not read after
// that in the visible code.
float scale_factor_x = float(in_limits[0u]) / float(out_width);
73+
float scale_factor_y = float(in_limits[1u]) / float(out_height);
74+
75+
// NOTE(review): both recip_scale_factor_* values are unconditionally
// reassigned by the align_corners if/else further down, so these
// initializers appear to be dead.
float recip_scale_factor_x = 1.0f / scale_factor_x;
76+
float recip_scale_factor_y = 1.0f / scale_factor_y;
7477

75-
// Reverse scale factors that pre-computed before GLSL.
7678
if (!graph.val_is_none(output_sizes)) {
77-
auto output_size_ref = graph.get_int_list(output_sizes);
78-
rev_scales = {
79-
utils::safe_downcast<float>(
80-
(float)input_size[0] / output_size_ref->at(1)),
81-
utils::safe_downcast<float>(
82-
(float)input_size[1] / output_size_ref->at(0))};
79+
// Explicit output size: it must agree with the extents of `out`.
IntListPtr output_size_ref = graph.get_int_list(output_sizes);
80+
out_width = output_size_ref->at(1);
81+
out_height = output_size_ref->at(0);
82+
83+
VK_CHECK_COND(out_width == out_limits[0u]);
84+
VK_CHECK_COND(out_height == out_limits[1u]);
85+
86+
} else {
87+
// Scale factors given: they are only validated against the extents of
// `out`; the values passed to the shader are derived from the extents.
DoubleListPtr scales = graph.get_double_list(scale_factors);
88+
scale_factor_x = scales->at(1);
89+
scale_factor_y = scales->at(0);
8390

91+
// NOTE(review): this is an exact floating-point equality. If
// in_limits * scale is not a whole number (e.g. 3 * 1.5) it cannot equal
// the integer output extent, so non-integer scale factors presumably fail
// this check - confirm whether that restriction is intended.
VK_CHECK_COND(in_limits[0u] * scale_factor_x == out_width);
92+
VK_CHECK_COND(in_limits[1u] * scale_factor_y == out_height);
93+
}
94+
95+
if (align_corners_val == 1) {
96+
// NOTE(review): when out_width or out_height is 1 the divisor is
// float(0), giving inf or NaN. The subtractions are also unsigned, so an
// extent of 0 would wrap. Consider guarding the size-1 case - TODO confirm
// the expected result against the reference op.
recip_scale_factor_x = float(in_limits[0u] - 1) / float(out_width - 1);
97+
recip_scale_factor_y = float(in_limits[1u] - 1) / float(out_height - 1);
8498
} else {
85-
auto scales = graph.get_double_list(scale_factors);
86-
rev_scales = {
87-
utils::safe_downcast<float>(1.0 / scales->at(1)),
88-
utils::safe_downcast<float>(1.0 / scales->at(0))};
99+
recip_scale_factor_x = float(in_limits[0u]) / float(out_width);
100+
recip_scale_factor_y = float(in_limits[1u]) / float(out_height);
89101
}
90102

91-
vTensorPtr t_out = graph.get_tensor(out);
103+
// Packed as (x, y) to match the vec2 recip_scales UBO read by the shader.
utils::vec2 recip_scales = {recip_scale_factor_x, recip_scale_factor_y};
92104

93-
std::string kernel_name("upsample_nearest2d");
105+
std::string kernel_name;
94106
kernel_name.reserve(kShaderNameReserve);
95-
add_dtype_suffix(kernel_name, *t_out);
107+
// NOTE(review): there is no default case, so a `mode` value outside the two
// enumerators would leave kernel_name empty before the suffix is appended.
switch (mode) {
108+
case UpsampleMode::NEAREST:
109+
kernel_name = "upsample_nearest2d";
110+
break;
111+
case UpsampleMode::BILINEAR:
112+
kernel_name = "upsample_bilinear2d";
113+
break;
114+
}
115+
add_dtype_suffix(kernel_name, graph.dtype_of(out));
96116

97117
graph.execute_nodes().emplace_back(new DispatchNode(
98118
graph,
@@ -103,21 +123,44 @@ void add_upsample_nearest2d_node(
103123
{{out, vkapi::MemoryAccessType::WRITE},
104124
{in, vkapi::MemoryAccessType::READ}},
105125
// Shader params buffers
106-
{t_out->logical_limits_ubo(),
107-
graph.create_params_buffer(input_size),
108-
graph.create_params_buffer(rev_scales)},
126+
// Order matches the shader's UBO declarations: out_limits, in_limits,
// recip_scales.
{graph.logical_limits_ubo(out),
127+
graph.logical_limits_ubo(in),
128+
graph.create_params_buffer(recip_scales)},
109129
// Specialization Constants
110-
{},
130+
{align_corners_val},
111131
resize_upsample_nearest2d_node,
112132
{output_sizes, scale_factors}));
113133
}
114134

115-
void upsample(ComputeGraph& graph, const std::vector<ValueRef>& args) {
116-
return add_upsample_nearest2d_node(graph, args[0], args[1], args[2], args[3]);
135+
// Operator entry point registered for aten.upsample_nearest2d.vec.
// Forwards args as (in, output_sizes, scale_factors, out) with NEAREST mode.
// This op has no align_corners argument, so kDummyValueRef is passed in that
// slot (presumably an invalid ref, which the callee treats as false - verify
// against the definition of kDummyValueRef).
void upsample_nearest2d(
136+
ComputeGraph& graph,
137+
const std::vector<ValueRef>& args) {
138+
return add_upsample_nearest2d_node(
139+
graph,
140+
UpsampleMode::NEAREST,
141+
args[0],
142+
args[1],
143+
kDummyValueRef,
144+
args[2],
145+
args[3]);
146+
}
147+
148+
// Operator entry point registered for aten.upsample_bilinear2d.vec.
// Forwards args as (in, output_sizes, align_corners, scale_factors, out) with
// BILINEAR mode. The shared node builder is named
// add_upsample_nearest2d_node but handles both modes.
void upsample_bilinear2d(
149+
ComputeGraph& graph,
150+
const std::vector<ValueRef>& args) {
151+
return add_upsample_nearest2d_node(
152+
graph,
153+
UpsampleMode::BILINEAR,
154+
args[0],
155+
args[1],
156+
args[2],
157+
args[3],
158+
args[4]);
117159
}
118160

119161
// Binds each ATen upsample op name to its entry-point function in this file.
REGISTER_OPERATORS {
120-
VK_REGISTER_OP(aten.upsample_nearest2d.vec, upsample);
162+
VK_REGISTER_OP(aten.upsample_nearest2d.vec, upsample_nearest2d);
163+
VK_REGISTER_OP(aten.upsample_bilinear2d.vec, upsample_bilinear2d);
121164
}
122165

123166
} // namespace vkcompute

backends/vulkan/test/op_tests/cases.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -430,21 +430,34 @@ def get_native_layer_norm_inputs():
430430
return test_suite
431431

432432

433-
@register_test_suite("aten.upsample_nearest2d.vec")
434433
def get_upsample_inputs():
435-
test_suite = VkTestSuite(
436-
[
437-
# (input tensor shape, output 2D image size (H, W), output scaling factors)
438-
((2, 2, 2, 2), None, [1, 1]),
439-
((1, 1, 2, 2), None, [2, 2]),
440-
((1, 1, 2, 2), None, [2, 4]),
441-
((1, 1, 2, 2), None, [4, 2]),
442-
((1, 1, 2, 2), [2, 2], None),
443-
((1, 1, 2, 2), [2, 4], None),
444-
((1, 1, 2, 2), [3, 2], None),
445-
]
446-
)
447-
return test_suite
434+
inputs_list = [
435+
# (input tensor shape, output 2D image size (H, W), output scaling factors)
436+
((2, 2, 2, 2), None, [1, 1]),
437+
((1, 1, 2, 2), None, [2, 2]),
438+
((1, 1, 2, 2), None, [2, 4]),
439+
((1, 1, 2, 2), None, [4, 2]),
440+
((1, 1, 2, 2), [2, 2], None),
441+
((1, 1, 2, 2), [2, 4], None),
442+
((1, 1, 2, 2), [3, 2], None),
443+
]
444+
return inputs_list
445+
446+
447+
# Test suite for aten.upsample_nearest2d.vec: wraps the shared cases returned
# by get_upsample_inputs() in a VkTestSuite without modification.
@register_test_suite("aten.upsample_nearest2d.vec")
448+
def get_upsample_nearest2d_inputs():
449+
inputs_list = get_upsample_inputs()
450+
return VkTestSuite(inputs_list)
451+
452+
453+
# Test suite for aten.upsample_bilinear2d.vec. Each shared case
# (input shape, output size, scale factors) from get_upsample_inputs() is
# expanded into two cases with a boolean inserted as the third element, once
# False and once True. That slot lines up with the align_corners argument that
# upsample_bilinear2d in Upsample.cpp reads from args[2].
# NOTE(review): none of the shared cases produces an output height or width of
# 1, so the True variant never exercises a size-1 output dimension.
@register_test_suite("aten.upsample_bilinear2d.vec")
454+
def get_upsample_bilinear2d_inputs():
455+
base_inputs_list = get_upsample_inputs()
456+
inputs_list = []
457+
for input_case in base_inputs_list:
458+
inputs_list.append((input_case[0], input_case[1], False, input_case[2]))
459+
inputs_list.append((input_case[0], input_case[1], True, input_case[2]))
460+
return VkTestSuite(inputs_list)
448461

449462

450463
@register_test_suite(["aten.full.default", "aten.full_like.default"])

0 commit comments

Comments
 (0)