Skip to content

Commit faaaff5

Browse files
authored
CANN: Support MUL_MAT_ID for q8_0 and q4_0 (#13705)
* [CANN]Support MUL_MAT_ID Q8 && Q4 Signed-off-by: noemotiovon <[email protected]> * codestyle adjustment Signed-off-by: noemotiovon <[email protected]> --------- Signed-off-by: noemotiovon <[email protected]>
1 parent e16c473 commit faaaff5

File tree

2 files changed

+142
-6
lines changed

2 files changed

+142
-6
lines changed

ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 133 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2697,14 +2697,10 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor*
26972697
}
26982698
}
26992699

2700-
// GroupedMatmulV2 required tensor_list.size < 128
27012700
size_t GROUP_SIZE = 128;
2702-
std::vector<std::vector<aclTensor*>> src0_tensor_vec_vec;
2703-
std::vector<std::vector<aclTensor*>> src1_tensor_vec_vec;
2704-
std::vector<std::vector<aclTensor*>> dst_tensor_vec_vec;
2705-
2706-
// split and call GroupedMatmulV2
2701+
// GroupedMatmulV2 required tensor_list.size < 128
27072702
for (size_t i = 0; i < src0_tensor_vec.size(); i += GROUP_SIZE) {
2703+
// split and call GroupedMatmulV2
27082704
size_t end = std::min(i + GROUP_SIZE, src0_tensor_vec.size());
27092705
std::vector<aclTensor*> src0_tensor_vec_split(src0_tensor_vec.begin() + i, src0_tensor_vec.begin() + end);
27102706
std::vector<aclTensor*> src1_tensor_vec_split(src1_tensor_vec.begin() + i, src1_tensor_vec.begin() + end);
@@ -2722,13 +2718,144 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor*
27222718
return;
27232719
}
27242720

2721+
/**
2722+
* @brief Performs expert-specific matrix multiplication (MoE) with
2723+
* quantized precision using the CANN backend.
2724+
*
2725+
* This function executes a matrix multiplication operation tailored for
2726+
* Mixture of Experts (MoE) models, where the input tensor is multiplied
2727+
* with expert-specific quantized weight matrices. It leverages the CANN
2728+
* backend to perform efficient low-precision computations and stores the
2729+
* result in the destination tensor `dst`.
2730+
*
2731+
* Quantization techniques reduce memory footprint and improve performance
2732+
* by using lower-bit representations (e.g., int8) instead of floating-point.
2733+
* This function is designed to work with such formats and may incorporate
2734+
* optimizations like identity-based fast paths or routing masks for sparse
2735+
* expert selection.
2736+
*
2737+
* @param ctx The context for executing CANN backend operations.
2738+
* @param dst The destination tensor where the MoE multiplication result
2739+
* will be stored.
2740+
*
2741+
* @note This function assumes quantized data types and is designed for
2742+
* MoE architectures with potential sparse expert routing.
2743+
*/
2744+
static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
2745+
// Overview: for every (iid1, id) entry of `ids`, the selected expert's weight
// and scale data are copied into one scratch buffer, and a single-row
// ggml_cann_mul_mat is issued on temporary tensor views (src0_row / src1_row /
// dst_row) that point at the relevant slices.
// TODO: Use aclnnGroupedMatMul
2746+
//dst [M, K, N, 1]
2747+
ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1]
2748+
ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1
2749+
ggml_tensor * ids = dst->src[2]; //ids [K, N]
2750+
2751+
// Introduces the ne0x / nb0x / ne1x / nb1x / ne / nb locals used below
// (ne00, ne01, ne02, ne03, nb02, ne11, nb11, nb12, nb1, nb2, ...).
GGML_TENSOR_BINARY_OP_LOCALS
2752+
2753+
// copy the expert indices from the device (NPU) to the host (CPU)
2754+
int64_t n_as = ne02; // A: number of experts, upper bound for the ids read below
2755+
int64_t n_ids = ids->ne[0]; // K: number of expert ids per row of `ids`
2756+
2757+
// Host-side staging buffer holding a byte-for-byte copy of `ids`.
std::vector<char> ids_host(ggml_nbytes(ids));
2758+
ggml_cann_async_memcpy(ctx, ids_host.data(), ids->data, ggml_nbytes(ids),
2759+
ACL_MEMCPY_DEVICE_TO_HOST);
2760+
// The copy above is asynchronous; synchronize the stream before ids_host is
// read on the host in the loop below.
ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));
2761+
2762+
// Base addresses of the full tensors; per-iteration pointers are derived
// from these by byte offsets.
char * src0_original = (char *) src0->data;
2763+
char * src1_original = (char *) src1->data;
2764+
char * dst_original = (char *) dst->data;
2765+
2766+
// Struct copies of the tensor descriptors; their ne/nb/data fields are
// rewritten below so each describes a single row, leaving the originals
// untouched.
ggml_tensor src0_row = *src0;
2767+
ggml_tensor src1_row = *src1;
2768+
ggml_tensor dst_row = *dst;
2769+
2770+
const enum ggml_type type = dst->src[0]->type;
2771+
// Bytes occupied by one weight element: half a byte for Q4_0, one byte for
// Q8_0. Kept as a float so the half-byte case can be represented.
float weight_elem_size;
2772+
if (type == GGML_TYPE_Q4_0) {
2773+
weight_elem_size = float(sizeof(uint8_t)) / 2;
2774+
} else if (type == GGML_TYPE_Q8_0) {
2775+
weight_elem_size = float(sizeof(uint8_t));
2776+
} else {
2777+
// presumably GGML_ABORT does not return, so weight_elem_size is never
// read uninitialized -- TODO confirm
GGML_ABORT("MUL_MAT_ID only support quant type Q4_0 and Q8_0 ");
2778+
}
2779+
2780+
// src0_row [D, M, 1, 1] weight without permute
2781+
src0_row.ne[2] = 1;
2782+
src0_row.ne[3] = 1;
2783+
// NOTE(review): weight_elem_size is a float (0.5 for Q4_0); if nb[] is an
// integer type these assignments truncate (nb[0] would become 0 for Q4_0)
// -- confirm that the consumer does not rely on nb[0].
src0_row.nb[0] = weight_elem_size;
2784+
src0_row.nb[1] = weight_elem_size * ne00;
2785+
src0_row.nb[2] = weight_elem_size * ne00;
2786+
src0_row.nb[3] = weight_elem_size * ne00;
2787+
// Bytes of weight data for one expert (ne00 * ne01 elements) ...
size_t weight_stride = ne00 * ne01 * weight_elem_size;
2788+
// ... and for all experts. The scale data is read starting at this byte
// offset from src0->data (see scale_tmp_ptr below), i.e. the code treats the
// src0 buffer as all weights followed by all scales.
size_t weight_size = weight_stride * ne02 * ne03;
2789+
2790+
// scale [D, M, 1, 1] -> scale && permute
2791+
// One 16-bit scale per group of QK8_0 elements.
// NOTE(review): QK8_0 is used for Q4_0 as well -- presumably both block
// sizes are equal; confirm.
size_t scale_elem_size = sizeof(uint16_t);
2792+
size_t scale_stride = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size;
2793+
2794+
// src1_row [D, 1, 1, 1] -> input
2795+
src1_row.ne[1] = 1;
2796+
src1_row.ne[2] = 1;
2797+
src1_row.ne[3] = 1;
2798+
src1_row.nb[2] = nb11;
2799+
src1_row.nb[3] = nb11;
2800+
2801+
// dst_row [M, 1, 1, 1] -> out
2802+
dst_row.ne[1] = 1;
2803+
dst_row.ne[2] = 1;
2804+
dst_row.ne[3] = 1;
2805+
dst_row.nb[2] = nb1;
2806+
dst_row.nb[3] = nb1;
2807+
2808+
// create the scratch buffer that holds one expert's weights followed by its
// scales; it is allocated once and reused by every loop iteration.
// NOTE(review): sized as nb02 while weight_stride + scale_stride bytes are
// written into it below -- assumes nb02 covers both; confirm.
2809+
ggml_cann_pool_alloc weight_allocator(ctx.pool());
2810+
void* weight_buffer = weight_allocator.alloc(nb02);
2811+
// Outer loop: rows of `ids` (ids->ne[1]); inner loop: the n_ids expert ids
// within that row.
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
2812+
for (int64_t id = 0; id < n_ids; id++) {
2813+
// expert index, read as int32 from the host copy using the byte strides of `ids`
2814+
int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
2815+
GGML_ASSERT(i02 >= 0 && i02 < n_as);
2816+
2817+
// If B = 1 (broadcast), always use 0; otherwise, use id.
2818+
int64_t i11 = (ne11 == 1 ? 0 : id);
2819+
int64_t i12 = iid1;
2820+
2821+
// Destination coordinates: dim 1 is the id slot, dim 2 is the `ids` row.
int64_t i1 = id;
2822+
int64_t i2 = i12;
2823+
2824+
// Byte addresses of this expert's weights and scales inside src0, of the
// input row inside src1, and of the output row inside dst.
void* src0_tmp_ptr = src0_original + i02*weight_stride;
2825+
void* scale_tmp_ptr = src0_original + weight_size + i02*scale_stride;
2826+
void* src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12;
2827+
void* dst_tmp_ptr = dst_original + i1*nb1 + i2*nb2;
2828+
2829+
// device-to-device copies: pack the expert's weights, then its scales
// immediately after them, into the scratch buffer
2830+
ggml_cann_async_memcpy(ctx, weight_buffer, src0_tmp_ptr, weight_stride,
2831+
ACL_MEMCPY_DEVICE_TO_DEVICE);
2832+
void* scale_buffer = (char*)weight_buffer + weight_stride;
2833+
ggml_cann_async_memcpy(ctx, scale_buffer, scale_tmp_ptr, scale_stride,
2834+
ACL_MEMCPY_DEVICE_TO_DEVICE);
2835+
2836+
// Point the temporary single-row views at the packed weights, the input
// row and the output row, and wire them up as dst_row's sources.
src0_row.data = weight_buffer;
2837+
src1_row.data = src1_tmp_ptr;
2838+
dst_row.data = dst_tmp_ptr;
2839+
dst_row.src[0] = &src0_row;
2840+
dst_row.src[1] = &src1_row;
2841+
2842+
// Run the regular mul_mat on the single-row views.
ggml_cann_mul_mat(ctx, &dst_row);
2843+
}
2844+
}
2845+
return;
2846+
}
2847+
27252848
void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
27262849
const enum ggml_type type = dst->src[0]->type;
27272850
switch (type) {
27282851
case GGML_TYPE_F32:
27292852
case GGML_TYPE_F16:
27302853
ggml_cann_mul_mat_id_fp(ctx, dst);
27312854
break;
2855+
case GGML_TYPE_Q4_0:
2856+
case GGML_TYPE_Q8_0:
2857+
ggml_cann_mul_mat_id_quant(ctx, dst);
2858+
break;
27322859
default:
27332860
GGML_ABORT("Unsupported type for mul_mat_id");
27342861
break;

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2035,6 +2035,15 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
20352035
case GGML_TYPE_F16:
20362036
case GGML_TYPE_F32:
20372037
return true;
2038+
case GGML_TYPE_Q8_0:
2039+
case GGML_TYPE_Q4_0:
2040+
#ifdef ASCEND_310P
2041+
// Q4 && Q8 per-group quantization is not supported on 310P devices
2042+
return false;
2043+
#endif
2044+
// only support contiguous for quantized types.
2045+
return ggml_is_contiguous(op->src[0]) &&
2046+
ggml_is_contiguous(op->src[1]);
20382047
default:
20392048
return false;
20402049
}

0 commit comments

Comments
 (0)