
Commit 27be1eb

opt on aligned address
wip
1 parent e7c9e2a commit 27be1eb

File tree

3 files changed: +100 -13 lines changed


ggml/src/ggml-qnn/npu/device/op_mul_mat.cpp

Lines changed: 61 additions & 11 deletions
@@ -168,6 +168,39 @@ bool is_quantized_mul_mat_supported(const npu_device_tensor_spec & src0, const n
     return true;
 }
 
+bool is_mulmat_tensors_aligned(hexagon::tensor * out) {
+    static_assert(DEVICE_TENSOR_MAX_DIMS == 4, "mul_mat_f32 requires max dims 4");
+    auto * src0 = out->get_src(0);
+    auto * src1 = out->get_src(1);
+
+    if (!hexagon::is_addr_aligned(src0->get_read_buffer()) || src0->get_nb(1) % hexagon::kBytesPerVector ||
+        !hexagon::is_addr_aligned(src1->get_read_buffer()) || src1->get_nb(1) % hexagon::kBytesPerVector) {
+        DEVICE_LOG_DEBUG(
+            "mul_mat_tensors_aligned: src0: %p, src1: %p, src0.nb[1]: %ld, src1.nb[1]: %ld "
+            "not aligned to %zu\n",
+            src0->get_read_buffer(), src1->get_read_buffer(), (long) src0->get_nb(1), (long) src1->get_nb(1),
+            hexagon::kBytesPerVector);
+        return false;
+    }
+
+    const auto src1_type_size = hexagon::get_type_traits(src1->get_type()).type_size;
+    if ((src1->get_ne(0) * src1_type_size) % hexagon::kBytesPerVector) {
+        DEVICE_LOG_DEBUG("mul_mat_tensors_aligned: src1.ne[0]: %ld, src1.type_size: %zu not aligned to %zu\n",
+                         (long) src1->get_ne(0), src1_type_size, hexagon::kBytesPerVector);
+        return false;
+    }
+
+    const auto & src0_traits    = hexagon::get_type_traits(src0->get_type());
+    const auto   src0_type_size = src0_traits.is_quantized ? sizeof(float) : src0_traits.type_size;
+    if ((src0->get_ne(0) * src0_type_size) % hexagon::kBytesPerVector) {
+        DEVICE_LOG_DEBUG("mul_mat_tensors_aligned: src0.ne[0]: %ld, src0.type_size: %zu not aligned to %zu\n",
+                         (long) src0->get_ne(0), src0_type_size, hexagon::kBytesPerVector);
+        return false;
+    }
+
+    return true;
+}
+
 }  // namespace
 
 namespace hexagon {
@@ -184,17 +217,34 @@ bool mul_mat_f32(hexagon::tensor * out, compute_params * params) {
         return true;  // skip if no src
     }
 
-    // TODO: array?
-    switch (src1->get_type()) {
-        case NPU_DATA_TYPE_F32:
-            mul_mat_impl<hexagon::vec_dot_product_f32_f32>(src0, src1, out, params);
-            return true;
-
-        case NPU_DATA_TYPE_F16:
-            mul_mat_impl<hexagon::vec_dot_product_f16_f16>(src0, src1, out, params);
-            return true;
-        default:
-            break;
+    if (is_mulmat_tensors_aligned(out)) {
+        DEVICE_LOG_DEBUG("mul_mat_f32: src0 and src1 aligned\n");
+
+        switch (src1->get_type()) {
+            case NPU_DATA_TYPE_F32:
+                mul_mat_impl<hexagon::vec_dot_product_aligned_f32_f32>(src0, src1, out, params);
+                return true;
+
+            case NPU_DATA_TYPE_F16:
+                mul_mat_impl<hexagon::vec_dot_product_aligned_f16_f16>(src0, src1, out, params);
+                return true;
+            default:
+                break;
+        }
+    } else {
+        DEVICE_LOG_DEBUG("mul_mat_f32: src0 or src1 not aligned\n");
+
+        switch (src1->get_type()) {
+            case NPU_DATA_TYPE_F32:
+                mul_mat_impl<hexagon::vec_dot_product_f32_f32>(src0, src1, out, params);
+                return true;
+
+            case NPU_DATA_TYPE_F16:
+                mul_mat_impl<hexagon::vec_dot_product_f16_f16>(src0, src1, out, params);
+                return true;
+            default:
+                break;
+        }
     }
 
     DEVICE_LOG_ERROR("Unsupported src1 tensor type: %s\n", get_type_name(src1->get_type()));
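
Behavior is unchanged for unaligned tensors: the new aligned kernels only run when is_mulmat_tensors_aligned passes. A minimal, host-buildable sketch of this gate-then-dispatch pattern (kVectorBytes, is_aligned, dot_generic, dot_aligned, and main are illustrative stand-ins, not code from this commit):

#include <cstddef>
#include <cstdint>
#include <cstdio>

constexpr size_t kVectorBytes = 128;  // HVX vector width on Hexagon

static bool is_aligned(const void * p) {
    return (reinterpret_cast<uintptr_t>(p) & (kVectorBytes - 1)) == 0;
}

// Generic path: handles any address and any element count.
static float dot_generic(const float * a, const float * b, size_t n) {
    float s = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        s += a[i] * b[i];
    }
    return s;
}

// Fast path: a real kernel would issue full-width aligned vector loads here;
// the scalar body just keeps the sketch runnable on any host.
static float dot_aligned(const float * a, const float * b, size_t n) {
    return dot_generic(a, b, n);
}

// Same conditions the commit checks: both base addresses vector-aligned
// and the row stride (nb[1]) a multiple of the vector width.
static float dot_dispatch(const float * a, const float * b, size_t n, size_t row_stride_bytes) {
    if (is_aligned(a) && is_aligned(b) && row_stride_bytes % kVectorBytes == 0) {
        return dot_aligned(a, b, n);
    }
    return dot_generic(a, b, n);
}

int main() {
    alignas(kVectorBytes) static float a[64];
    alignas(kVectorBytes) static float b[64];
    for (int i = 0; i < 64; ++i) {
        a[i] = 1.0f;
        b[i] = 2.0f;
    }
    std::printf("%f\n", dot_dispatch(a, b, 64, 64 * sizeof(float)));  // prints 128.000000
}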

ggml/src/ggml-qnn/npu/device/vec_ops.cpp

Lines changed: 36 additions & 1 deletion
@@ -90,6 +90,32 @@ inline float vec_dot_product_impl(const _TElem * src0, const _TElem * src1, size
     return _ReduceFunc(sum);
 }
 
+template <typename _TElem, HVX_Vector (*_MpyFunc)(HVX_Vector, HVX_Vector),
+          HVX_Vector (*_AddFunc)(HVX_Vector, HVX_Vector), float (*_ReduceFunc)(HVX_Vector)>
+inline float vec_dot_product_aligned_impl(const _TElem * src0, const _TElem * src1, size_t count) {
+    constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TElem);
+
+    HVX_Vector * src0_vec_ptr     = ((HVX_Vector *) src0);
+    HVX_Vector * src0_vec_ptr_end = ((HVX_Vector *) src0) + count / kElementsPerVector;
+    HVX_Vector * src1_vec_ptr     = ((HVX_Vector *) src1);
+    HVX_Vector   sum0             = Q6_V_vzero();
+    HVX_Vector   sum1             = Q6_V_vzero();
+
+    while (src0_vec_ptr_end - src0_vec_ptr > 1) {
+        HVX_Vector curr0_lo = src0_vec_ptr[0];
+        HVX_Vector curr0_hi = src0_vec_ptr[1];
+        HVX_Vector curr1_lo = src1_vec_ptr[0];
+        HVX_Vector curr1_hi = src1_vec_ptr[1];
+        src0_vec_ptr += 2;
+        src1_vec_ptr += 2;
+
+        sum0 = _AddFunc(_MpyFunc(curr0_lo, curr1_lo), sum0);
+        sum1 = _AddFunc(_MpyFunc(curr0_hi, curr1_hi), sum1);
+    }
+
+    return _ReduceFunc(_AddFunc(sum0, sum1));
+}
+
 inline HVX_Vector vec_mpy_qf32(HVX_Vector src0, HVX_Vector src1) {
     return Q6_Vqf32_vmpy_VsfVsf(src0, src1);
 }
@@ -114,10 +140,19 @@ float vec_dot_product_f32_f32(const float * src0, const float * src1, size_t cou
     return vec_dot_product_impl<float, vec_mpy_qf32, vec_add_qf32, hexagon::vec_reduction_qf32_f32>(src0, src1, count);
 }
 
-// TODO: merge with vec_dot_product_f32_f32?
+float vec_dot_product_aligned_f32_f32(const float * src0, const float * src1, size_t count) {
+    return vec_dot_product_aligned_impl<float, vec_mpy_qf32, vec_add_qf32, hexagon::vec_reduction_qf32_f32>(src0, src1,
+                                                                                                            count);
+}
+
 float vec_dot_product_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count) {
     return vec_dot_product_impl<npu_device_fp16_t, vec_mpy_qf16, vec_add_qf16, hexagon::vec_reduction_qf16_f32>(
         src0, src1, count);
 }
 
+float vec_dot_product_aligned_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count) {
+    return vec_dot_product_aligned_impl<npu_device_fp16_t, vec_mpy_qf16, vec_add_qf16, hexagon::vec_reduction_qf16_f32>(
+        src0, src1, count);
+}
+
 }  // namespace hexagon
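
Compared with vec_dot_product_impl, the aligned variant can load whole vectors directly and keeps two independent accumulators so consecutive multiply-adds do not serialize on one register. As written, the loop stops once fewer than two whole vectors remain, so an odd trailing vector appears to be left unaccumulated (likely part of the "wip"). A portable sketch of the two-accumulator idea, with kBlock and block_dot standing in for the 128-byte HVX vectors and Q6 intrinsics:

#include <cstddef>
#include <cstdio>

constexpr size_t kBlock = 32;  // stand-in for the elements in one HVX vector

// One "vector" worth of multiply-accumulate (stand-in for _MpyFunc/_AddFunc).
static float block_dot(const float * a, const float * b) {
    float s = 0.0f;
    for (size_t i = 0; i < kBlock; ++i) {
        s += a[i] * b[i];
    }
    return s;
}

// Two independent accumulators, two blocks per iteration, mirroring the
// `while (src0_vec_ptr_end - src0_vec_ptr > 1)` loop above.
static float dot_two_accumulators(const float * a, const float * b, size_t count) {
    const size_t blocks = count / kBlock;  // caller guarantees count % kBlock == 0
    float  sum0 = 0.0f;
    float  sum1 = 0.0f;
    size_t i    = 0;
    for (; i + 2 <= blocks; i += 2) {
        sum0 += block_dot(a + i * kBlock, b + i * kBlock);
        sum1 += block_dot(a + (i + 1) * kBlock, b + (i + 1) * kBlock);
    }
    // Odd trailing block; the HVX loop in this commit stops before it.
    if (i < blocks) {
        sum0 += block_dot(a + i * kBlock, b + i * kBlock);
    }
    return sum0 + sum1;
}

int main() {
    static float a[96];
    static float b[96];
    for (int i = 0; i < 96; ++i) {
        a[i] = 1.0f;
        b[i] = 0.5f;
    }
    std::printf("%f\n", dot_two_accumulators(a, b, 96));  // prints 48.000000
}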

ggml/src/ggml-qnn/npu/device/vec_ops.hpp

Lines changed: 3 additions & 1 deletion
@@ -16,7 +16,7 @@ inline size_t unaligned_bytes(const void * addr) {
     return ((size_t) addr) & kAlignMask;
 }
 
-inline bool is_addr_aligned(void * addr) {
+inline bool is_addr_aligned(const void * addr) {
     return unaligned_bytes(addr) == 0;
 }

@@ -275,7 +275,9 @@ inline void vec_mad_f16(const npu_device_fp16_t * src, float scale, npu_device_f
 }
 
 float vec_dot_product_f32_f32(const float * src0, const float * src1, size_t count);
+float vec_dot_product_aligned_f32_f32(const float * src0, const float * src1, size_t count);
 
 float vec_dot_product_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count);
+float vec_dot_product_aligned_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count);
 
 }  // namespace hexagon
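
Making is_addr_aligned take const void * lets the new checker pass get_read_buffer() results without a cast. A self-contained illustration of the two helpers, assuming kAlignMask is defined as kBytesPerVector - 1 (128-byte HVX vectors):

#include <cstddef>
#include <cstdio>

constexpr size_t kBytesPerVector = 128;                  // HVX vector width
constexpr size_t kAlignMask      = kBytesPerVector - 1;  // assumed mask definition

inline size_t unaligned_bytes(const void * addr) {
    return ((size_t) addr) & kAlignMask;
}

inline bool is_addr_aligned(const void * addr) {
    return unaligned_bytes(addr) == 0;
}

int main() {
    alignas(kBytesPerVector) static unsigned char buf[256];
    // Base is aligned; base + 5 is 5 bytes past the previous 128-byte boundary.
    std::printf("%d %zu\n", is_addr_aligned(buf), unaligned_bytes(buf + 5));  // prints "1 5"
}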
