@@ -168,6 +168,39 @@ bool is_quantized_mul_mat_supported(const npu_device_tensor_spec & src0, const n
     return true;
 }

+bool is_mulmat_tensors_aligned(hexagon::tensor * out) {
+    static_assert(DEVICE_TENSOR_MAX_DIMS == 4, "mul_mat_f32 requires max dims 4");
+    auto * src0 = out->get_src(0);
+    auto * src1 = out->get_src(1);
+
+    if (!hexagon::is_addr_aligned(src0->get_read_buffer()) || src0->get_nb(1) % hexagon::kBytesPerVector ||
+        !hexagon::is_addr_aligned(src1->get_read_buffer()) || src1->get_nb(1) % hexagon::kBytesPerVector) {
+        DEVICE_LOG_DEBUG(
+            "mul_mat_tensors_aligned: src0: %p, src1: %p, src0.nb[1]: %ld, src1.nb[1]: %ld "
+            "not aligned to %zu\n",
+            src0->get_read_buffer(), src1->get_read_buffer(), (long) src0->get_nb(1), (long) src1->get_nb(1),
+            hexagon::kBytesPerVector);
+        return false;
+    }
+
+    const auto src1_type_size = hexagon::get_type_traits(src1->get_type()).type_size;
+    if ((src1->get_ne(0) * src1_type_size) % hexagon::kBytesPerVector) {
+        DEVICE_LOG_DEBUG("mul_mat_tensors_aligned: src1.ne[0]: %ld, src1.type_size: %zu not aligned to %zu\n",
+                         (long) src1->get_ne(0), src1_type_size, hexagon::kBytesPerVector);
+        return false;
+    }
+
+    const auto & src0_traits = hexagon::get_type_traits(src0->get_type());
+    const auto src0_type_size = src0_traits.is_quantized ? sizeof(float) : src0_traits.type_size;
+    if ((src0->get_ne(0) * src0_type_size) % hexagon::kBytesPerVector) {
+        DEVICE_LOG_DEBUG("mul_mat_tensors_aligned: src0.ne[0]: %ld, src0.type_size: %zu not aligned to %zu\n",
+                         (long) src0->get_ne(0), src0_type_size, hexagon::kBytesPerVector);
+        return false;
+    }
+
+    return true;
+}
+
 } // namespace

 namespace hexagon {
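
The new gate leans on two backend helpers that are not part of this diff, hexagon::is_addr_aligned and hexagon::kBytesPerVector. As a rough sketch of what they typically amount to on HVX, whose vector registers are commonly 128 bytes wide, something like the following would fit; the constant value and the helper body are assumptions for illustration, not the repository's actual definitions.

// Sketch only: assumed shapes of the helpers referenced above.
#include <cstddef>
#include <cstdint>

namespace hexagon {

// Assumed HVX vector width in bytes; buffers and row strides are checked
// against this granularity before taking the aligned fast path.
constexpr size_t kBytesPerVector = 128;

// True when a buffer starts on a vector boundary, so the kernel can use
// aligned vector loads from the first element onward.
inline bool is_addr_aligned(const void * addr) {
    return (reinterpret_cast<uintptr_t>(addr) % kBytesPerVector) == 0;
}

}  // namespace hexagon

With definitions along these lines, the gate above requires that both input buffers start on a vector boundary and that each row's stride (nb[1]) and byte length (ne[0] * type_size) are whole multiples of the vector width, so every row the kernel walks begins and ends on a vector boundary.
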
@@ -184,17 +217,34 @@ bool mul_mat_f32(hexagon::tensor * out, compute_params * params) {
         return true;  // skip if no src
     }

-    // TODO: array?
-    switch (src1->get_type()) {
-        case NPU_DATA_TYPE_F32:
-            mul_mat_impl<hexagon::vec_dot_product_f32_f32>(src0, src1, out, params);
-            return true;
-
-        case NPU_DATA_TYPE_F16:
-            mul_mat_impl<hexagon::vec_dot_product_f16_f16>(src0, src1, out, params);
-            return true;
-        default:
-            break;
+    if (is_mulmat_tensors_aligned(out)) {
+        DEVICE_LOG_DEBUG("mul_mat_f32: src0 and src1 aligned\n");
+
+        switch (src1->get_type()) {
+            case NPU_DATA_TYPE_F32:
+                mul_mat_impl<hexagon::vec_dot_product_aligned_f32_f32>(src0, src1, out, params);
+                return true;
+
+            case NPU_DATA_TYPE_F16:
+                mul_mat_impl<hexagon::vec_dot_product_aligned_f16_f16>(src0, src1, out, params);
+                return true;
+            default:
+                break;
+        }
+    } else {
+        DEVICE_LOG_DEBUG("mul_mat_f32: src0 or src1 not aligned\n");
+
+        switch (src1->get_type()) {
+            case NPU_DATA_TYPE_F32:
+                mul_mat_impl<hexagon::vec_dot_product_f32_f32>(src0, src1, out, params);
+                return true;
+
+            case NPU_DATA_TYPE_F16:
+                mul_mat_impl<hexagon::vec_dot_product_f16_f16>(src0, src1, out, params);
+                return true;
+            default:
+                break;
+        }
     }

     DEVICE_LOG_ERROR("Unsupported src1 tensor type: %s\n", get_type_name(src1->get_type()));
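
When the gate passes, mul_mat_f32 dispatches to the _aligned_ vec-dot specializations, which can presumably issue whole-vector loads without patching up a misaligned head or tail; otherwise it falls back to the generic kernels. The sketch below illustrates that dispatch pattern in plain scalar C++. dot_aligned, dot_unaligned, and row_dot are stand-ins invented for illustration, not the backend's actual HVX kernels.

// Hedged sketch of the aligned-vs-unaligned dispatch idea in plain C++.
#include <cassert>
#include <cstddef>
#include <cstdint>

constexpr size_t kBytesPerVector  = 128;  // assumed HVX vector width in bytes
constexpr size_t kFloatsPerVector = kBytesPerVector / sizeof(float);

// Generic path: no assumptions about pointers or length.
static float dot_unaligned(const float * a, const float * b, size_t n) {
    float sum = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        sum += a[i] * b[i];
    }
    return sum;
}

// "Aligned" path: may assume both rows start on a vector boundary and n is a
// whole number of vectors, so a real HVX kernel could use only full aligned
// loads with no scalar head/tail handling.
static float dot_aligned(const float * a, const float * b, size_t n) {
    assert(reinterpret_cast<uintptr_t>(a) % kBytesPerVector == 0);
    assert(reinterpret_cast<uintptr_t>(b) % kBytesPerVector == 0);
    assert(n % kFloatsPerVector == 0);
    float sum = 0.0f;
    for (size_t i = 0; i < n; i += kFloatsPerVector) {
        // Each outer step stands in for one full-vector multiply-accumulate.
        for (size_t j = 0; j < kFloatsPerVector; ++j) {
            sum += a[i + j] * b[i + j];
        }
    }
    return sum;
}

// Caller mirrors mul_mat_f32: probe alignment once, then pick the kernel.
float row_dot(const float * a, const float * b, size_t n) {
    const bool aligned = reinterpret_cast<uintptr_t>(a) % kBytesPerVector == 0 &&
                         reinterpret_cast<uintptr_t>(b) % kBytesPerVector == 0 &&
                         (n * sizeof(float)) % kBytesPerVector == 0;
    return aligned ? dot_aligned(a, b, n) : dot_unaligned(a, b, n);
}

Probing alignment once per destination tensor, as the diff does with is_mulmat_tensors_aligned(out), keeps the aligned/unaligned decision out of the per-row inner loop.
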