16
16
17
17
namespace vkcompute {
18
18
19
/// Interpolation scheme requested by an upsample operator. The node builder
/// in this file switches on this value to pick the compute shader name.
enum class UpsampleMode : int {
  NEAREST = 0, ///< nearest-neighbour sampling
  BILINEAR = 1, ///< bilinear interpolation
};
20
+
19
21
void resize_upsample_nearest2d_node (
20
22
ComputeGraph* graph,
21
23
const std::vector<ArgGroup>& args,
@@ -39,19 +41,12 @@ void resize_upsample_nearest2d_node(
39
41
out->virtual_resize (out_sizes);
40
42
}
41
43
42
- // ExecuTorch-Vulkan framework to add node
43
- // Args:
44
- // in: will be converted from NCHW input tensor to 3D ARGB representation in
45
- // openGL (via ExecuTorch) output_sizes: optional 2D array of targeting
46
- // output size of H and W dimensions. >= input sizes;
47
-
48
- // will be computed if only given the scale_factors.
49
- // scale_factors: optional 2D array of scale factors for H and W dimensions.
50
- // Will be computed if only given the output_sizes.
51
44
void add_upsample_nearest2d_node (
52
45
ComputeGraph& graph,
46
+ const UpsampleMode mode,
53
47
const ValueRef in,
54
48
const ValueRef output_sizes,
49
+ const ValueRef align_corners,
55
50
const ValueRef scale_factors,
56
51
const ValueRef out) {
57
52
if (graph.val_is_none (output_sizes) && graph.val_is_none (scale_factors)) {
@@ -63,36 +58,61 @@ void add_upsample_nearest2d_node(
63
58
" Invalid input, must provide ONLY one of output_sizes or scale_factors" );
64
59
}
65
60
66
- vTensorPtr t_in = graph.get_tensor (in);
67
- utils::uvec3 input_sizes = t_in->logical_limits ();
61
+ int align_corners_val = 0 ;
62
+ if (is_valid (align_corners) && graph.get_bool (align_corners)) {
63
+ align_corners_val = 1 ;
64
+ }
65
+
66
+ utils::uvec3 in_limits = graph.logical_limits_of (in);
67
+ utils::uvec3 out_limits = graph.logical_limits_of (out);
68
+
69
+ uint32_t out_width = out_limits[0u ];
70
+ uint32_t out_height = out_limits[1u ];
68
71
69
- utils::ivec2 input_size = {
70
- utils::safe_downcast< int32_t >(input_sizes[ 0 ]),
71
- utils::safe_downcast< int32_t >(input_sizes[ 1 ])};
72
- utils::vec2 rev_scales = {
73
- utils::safe_downcast< float >( 1.0 ), utils::safe_downcast< float >( 1.0 )} ;
72
+ float scale_factor_x = float (in_limits[ 0u ]) / float (out_width);
73
+ float scale_factor_y = float (in_limits[ 1u ]) / float (out_height);
74
+
75
+ float recip_scale_factor_x = 1 . 0f / scale_factor_x;
76
+ float recip_scale_factor_y = 1 . 0f / scale_factor_y ;
74
77
75
- // Reverse scale factors that pre-computed before GLSL.
76
78
if (!graph.val_is_none (output_sizes)) {
77
- auto output_size_ref = graph.get_int_list (output_sizes);
78
- rev_scales = {
79
- utils::safe_downcast<float >(
80
- (float )input_size[0 ] / output_size_ref->at (1 )),
81
- utils::safe_downcast<float >(
82
- (float )input_size[1 ] / output_size_ref->at (0 ))};
79
+ IntListPtr output_size_ref = graph.get_int_list (output_sizes);
80
+ out_width = output_size_ref->at (1 );
81
+ out_height = output_size_ref->at (0 );
82
+
83
+ VK_CHECK_COND (out_width == out_limits[0u ]);
84
+ VK_CHECK_COND (out_height == out_limits[1u ]);
85
+
86
+ } else {
87
+ DoubleListPtr scales = graph.get_double_list (scale_factors);
88
+ scale_factor_x = scales->at (1 );
89
+ scale_factor_y = scales->at (0 );
83
90
91
+ VK_CHECK_COND (in_limits[0u ] * scale_factor_x == out_width);
92
+ VK_CHECK_COND (in_limits[1u ] * scale_factor_y == out_height);
93
+ }
94
+
95
+ if (align_corners_val == 1 ) {
96
+ recip_scale_factor_x = float (in_limits[0u ] - 1 ) / float (out_width - 1 );
97
+ recip_scale_factor_y = float (in_limits[1u ] - 1 ) / float (out_height - 1 );
84
98
} else {
85
- auto scales = graph.get_double_list (scale_factors);
86
- rev_scales = {
87
- utils::safe_downcast<float >(1.0 / scales->at (1 )),
88
- utils::safe_downcast<float >(1.0 / scales->at (0 ))};
99
+ recip_scale_factor_x = float (in_limits[0u ]) / float (out_width);
100
+ recip_scale_factor_y = float (in_limits[1u ]) / float (out_height);
89
101
}
90
102
91
- vTensorPtr t_out = graph. get_tensor (out) ;
103
+ utils::vec2 recip_scales = {recip_scale_factor_x, recip_scale_factor_y} ;
92
104
93
- std::string kernel_name ( " upsample_nearest2d " ) ;
105
+ std::string kernel_name;
94
106
kernel_name.reserve (kShaderNameReserve );
95
- add_dtype_suffix (kernel_name, *t_out);
107
+ switch (mode) {
108
+ case UpsampleMode::NEAREST:
109
+ kernel_name = " upsample_nearest2d" ;
110
+ break ;
111
+ case UpsampleMode::BILINEAR:
112
+ kernel_name = " upsample_bilinear2d" ;
113
+ break ;
114
+ }
115
+ add_dtype_suffix (kernel_name, graph.dtype_of (out));
96
116
97
117
graph.execute_nodes ().emplace_back (new DispatchNode (
98
118
graph,
@@ -103,21 +123,44 @@ void add_upsample_nearest2d_node(
103
123
{{out, vkapi::MemoryAccessType::WRITE},
104
124
{in, vkapi::MemoryAccessType::READ}},
105
125
// Shader params buffers
106
- {t_out-> logical_limits_ubo (),
107
- graph.create_params_buffer (input_size ),
108
- graph.create_params_buffer (rev_scales )},
126
+ {graph. logical_limits_ubo (out ),
127
+ graph.logical_limits_ubo (in ),
128
+ graph.create_params_buffer (recip_scales )},
109
129
// Specialization Constants
110
- {},
130
+ {align_corners_val },
111
131
resize_upsample_nearest2d_node,
112
132
{output_sizes, scale_factors}));
113
133
}
114
134
115
- void upsample (ComputeGraph& graph, const std::vector<ValueRef>& args) {
116
- return add_upsample_nearest2d_node (graph, args[0 ], args[1 ], args[2 ], args[3 ]);
135
+ void upsample_nearest2d (
136
+ ComputeGraph& graph,
137
+ const std::vector<ValueRef>& args) {
138
+ return add_upsample_nearest2d_node (
139
+ graph,
140
+ UpsampleMode::NEAREST,
141
+ args[0 ],
142
+ args[1 ],
143
+ kDummyValueRef ,
144
+ args[2 ],
145
+ args[3 ]);
146
+ }
147
+
148
+ void upsample_bilinear2d (
149
+ ComputeGraph& graph,
150
+ const std::vector<ValueRef>& args) {
151
+ return add_upsample_nearest2d_node (
152
+ graph,
153
+ UpsampleMode::BILINEAR,
154
+ args[0 ],
155
+ args[1 ],
156
+ args[2 ],
157
+ args[3 ],
158
+ args[4 ]);
117
159
}
118
160
119
161
// Bind each ATen upsample overload to its entry point in this file. Each op is
// registered exactly once: nearest and bilinear have separate wrappers because
// their argument lists differ (bilinear carries an extra align_corners value).
REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.upsample_nearest2d.vec, upsample_nearest2d);
  VK_REGISTER_OP(aten.upsample_bilinear2d.vec, upsample_bilinear2d);
}
122
165
123
166
} // namespace vkcompute
0 commit comments