@@ -23,8 +23,7 @@ void check_args(
23
23
const api::vTensor& in,
24
24
const std::vector<int64_t >& repeats,
25
25
const api::vTensor& out) {
26
- VK_CHECK_COND (check_packed_dim_is (in, WHCN::kChannelsDim ));
27
- VK_CHECK_COND (check_packed_dim_is (out, WHCN::kChannelsDim ));
26
+ VK_CHECK_COND (check_same_packed_dim (in, out));
28
27
29
28
VK_CHECK_COND (in.storage_type () == out.storage_type ());
30
29
if (in.storage_type () == utils::kTexture2D ) {
@@ -59,147 +58,29 @@ void check_args(
59
58
60
59
} // namespace
61
60
62
- void add_repeat_channel_node (
63
- ComputeGraph& graph,
64
- ValueRef in,
65
- int64_t repeat_channel,
66
- ValueRef out,
67
- utils::ivec3& running_range) {
68
- vTensorPtr t_in = graph.get_tensor (in);
69
- vTensorPtr t_out = graph.get_tensor (out);
70
-
71
- std::string kernel_name = " repeat_channel" ;
72
- kernel_name.reserve (kShaderNameReserve );
73
- add_dtype_suffix (kernel_name, *t_out);
74
-
75
- const std::vector<int64_t >& in_sizes = t_in->sizes ();
76
-
77
- int32_t in_width = utils::safe_downcast<int32_t >(dim_at<kWidth4D >(in_sizes));
78
- int32_t in_height =
79
- utils::safe_downcast<int32_t >(dim_at<kHeight4D >(in_sizes));
80
- int32_t in_channel =
81
- utils::safe_downcast<int32_t >(dim_at<kChannel4D >(in_sizes));
82
- int32_t in_batch = utils::safe_downcast<int32_t >(dim_at<kBatch4D >(in_sizes));
83
-
84
- int32_t out_channel = repeat_channel * in_channel;
85
-
86
- utils::ivec4 out_whcn_sizes{in_width, in_height, out_channel, in_batch};
87
-
88
- utils::ivec4 in_whcn_sizes{in_width, in_height, in_channel, in_batch};
89
-
90
- // Channel packed global work ids
91
- running_range[2 ] = out_whcn_sizes[3 ] * utils::div_up_4 (out_whcn_sizes[2 ]);
92
- utils::uvec3 global_size = utils::make_uvec3 (running_range);
93
- utils::uvec3 local_size = adaptive_work_group_size (global_size);
94
-
95
- const struct Block final {
96
- utils::ivec4 out_sizes;
97
- utils::ivec4 in_size;
98
- } repeat_channel_args{
99
- out_whcn_sizes,
100
- in_whcn_sizes,
101
- };
102
-
103
- auto shader = VK_KERNEL_FROM_STR (kernel_name);
104
-
105
- graph.execute_nodes ().emplace_back (new DispatchNode (
106
- graph,
107
- VK_KERNEL_FROM_STR (kernel_name),
108
- global_size,
109
- local_size,
110
- // Inputs and Outputs
111
- {{out, vkapi::MemoryAccessType::WRITE},
112
- {in, vkapi::MemoryAccessType::READ}},
113
- // Parameter buffers
114
- {graph.create_params_buffer (repeat_channel_args)},
115
- // Specialization Constants
116
- {SV (t_out->packed_dim ())}));
117
- }
118
-
119
61
void add_repeat_node (
120
62
ComputeGraph& graph,
121
63
ValueRef in,
122
64
ValueRef repeats_ref,
123
65
ValueRef out) {
124
- std::vector<int64_t > repeats = *(graph.get_int_list (repeats_ref));
66
+ const std::vector<int64_t > repeats = *(graph.get_int_list (repeats_ref));
125
67
126
68
vTensorPtr t_in = graph.get_tensor (in);
127
69
vTensorPtr t_out = graph.get_tensor (out);
128
70
check_args (*t_in, repeats, *t_out);
129
71
130
- // In this function, we expand the dimensions in the following order:
131
- // 1. Channel
132
- // 2. Width
133
- // 3. Height
134
- // 4. Batch
135
- // After expanding a dimension, we will update the "running_range" since we
136
- // will need to copy the "expanded" area.
137
-
138
- utils::ivec3 running_range = t_in->logical_limits ();
139
-
140
- const std::vector<int64_t >& in_sizes = t_in->sizes ();
141
-
142
- // Since we use channel packing, repeating the channel dimension is the most
143
- // complicated and time-consuming, as we need to reason over misaligned
144
- // channels. Hence we expand it first to minimize cost. Also, in this first
145
- // dimension, we copy over the input texure to the output. In subsequent
146
- // dimensions, we read and write from the same tensor.
147
-
148
- if (int64_t channel_repeat = dim_at<kChannel4D >(repeats);
149
- channel_repeat == 1 ) {
150
- // If no repeat, short-cut to a direct copy
151
- utils::ivec4 src_offset{0 , 0 , 0 , 0 };
152
- utils::ivec4 dst_offset{0 , 0 , 0 , 0 };
153
-
154
- add_copy_offset_node (
155
- graph, in, running_range, src_offset, dst_offset, out, false , false );
156
-
157
- } else {
158
- add_repeat_channel_node (graph, in, channel_repeat, out, running_range);
159
- }
160
-
161
- // TODO: refactor width, height, and batch into a common helper function.
162
- // Width
163
- if (int64_t width_repeat = dim_at<kWidth4D >(repeats); width_repeat > 1 ) {
164
- utils::ivec4 src_offset{0 , 0 , 0 , 0 };
165
-
166
- for (int i = 1 ; i < width_repeat; ++i) {
167
- utils::ivec4 dst_offset{i * dim_at<kWidth4D >(in_sizes), 0 , 0 , 0 };
168
-
169
- add_copy_offset_node (
170
- graph, out, running_range, src_offset, dst_offset, out, true , false );
171
- }
172
-
173
- running_range[0 ] = running_range[0 ] * width_repeat;
174
- }
175
-
176
- // Height
177
- if (int64_t height_repeat = dim_at<kHeight4D >(repeats); height_repeat > 1 ) {
178
- utils::ivec4 src_offset{0 , 0 , 0 , 0 };
179
-
180
- for (int i = 1 ; i < height_repeat; ++i) {
181
- utils::ivec4 dst_offset = {0 , i * dim_at<kHeight4D >(in_sizes), 0 , 0 };
182
-
183
- add_copy_offset_node (
184
- graph, out, running_range, src_offset, dst_offset, out, true , false );
185
- }
186
-
187
- running_range[1 ] = running_range[1 ] * height_repeat;
188
- }
189
-
190
- // Batch
191
- if (int64_t batch_repeat = dim_at<kBatch4D >(repeats); batch_repeat > 1 ) {
192
- utils::ivec4 src_offset{0 , 0 , 0 , 0 };
193
-
194
- for (int i = 1 ; i < batch_repeat; ++i) {
195
- utils::ivec4 dst_offset = {0 , 0 , i * running_range[2 ], 0 };
196
-
197
- add_copy_offset_node (
198
- graph, out, running_range, src_offset, dst_offset, out, true , false );
199
- }
200
-
201
- running_range[2 ] = running_range[2 ] * batch_repeat;
202
- }
72
+ const utils::ivec4 src_offset{
73
+ dim_at<kWidth4D >(t_in->sizes ()),
74
+ dim_at<kHeight4D >(t_in->sizes ()),
75
+ dim_at<kChannel4D >(t_in->sizes ()),
76
+ dim_at<kBatch4D >(t_in->sizes ())};
77
+ const utils::ivec4 dst_offset{
78
+ dim_at<kWidth4D >(repeats),
79
+ dim_at<kHeight4D >(repeats),
80
+ dim_at<kChannel4D >(repeats),
81
+ dim_at<kBatch4D >(repeats)};
82
+ add_copy_packed_dim_offset_node (
83
+ graph, in, t_out->logical_limits (), src_offset, dst_offset, out, true );
203
84
}
204
85
205
86
void repeat (ComputeGraph& graph, const std::vector<ValueRef>& args) {
0 commit comments