Skip to content

Commit f93af6e

Browse files
Signed-off-by: Ahmad Tameem <[email protected]>
Add RISC-V Vector Intrinsics Support. Added RVV intrinsics for the following functions: ggml_vec_dot_q4_0_q8_0, ggml_vec_dot_q4_1_q8_1, ggml_vec_dot_q5_0_q8_0, ggml_vec_dot_q5_1_q8_1, ggml_vec_dot_q8_0_q8_0. Co-authored-by: Sharafat <[email protected]>
1 parent bc23fcd commit f93af6e

File tree

1 file changed

+227
-0
lines changed

1 file changed

+227
-0
lines changed

ggml.c

Lines changed: 227 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -301,6 +301,10 @@ typedef double ggml_float;
301301
#endif
302302
#endif
303303

304+
#ifdef __riscv_v_intrinsic
305+
#include <riscv_vector.h>
306+
#endif
307+
304308
#ifdef __F16C__
305309

306310
#ifdef _MSC_VER
@@ -2677,6 +2681,41 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
26772681
}
26782682

26792683
*s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
2684+
#elif defined(__riscv_v_intrinsic)
2685+
float sumf = 0.0;
2686+
2687+
size_t vl = __riscv_vsetvl_e8m1(qk/2);
2688+
2689+
for (int i = 0; i < nb; i++) {
2690+
vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl);
2691+
2692+
vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
2693+
vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl);
2694+
2695+
vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
2696+
vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
2697+
2698+
vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
2699+
vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
2700+
2701+
vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 8, vl);
2702+
vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 8, vl);
2703+
2704+
vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
2705+
vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl);
2706+
2707+
vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
2708+
2709+
vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl);
2710+
vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
2711+
2712+
int sumi = __riscv_vmv_x_s_i32m1_i32(vs1);
2713+
sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
2714+
2715+
sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
2716+
}
2717+
2718+
*s = sumf;
26802719
#else
26812720
// scalar
26822721
float sumf = 0.0;
@@ -2803,6 +2842,38 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
28032842
}
28042843

28052844
*s = hsum_float_8(acc) + summs;
2845+
#elif defined(__riscv_v_intrinsic)
2846+
float sumf = 0.0;
2847+
2848+
size_t vl = __riscv_vsetvl_e8m1(qk/2);
2849+
2850+
for (int i = 0; i < nb; i++) {
2851+
vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl);
2852+
2853+
vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
2854+
vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl);
2855+
2856+
vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
2857+
vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
2858+
2859+
vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
2860+
vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
2861+
2862+
vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
2863+
vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl);
2864+
2865+
vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
2866+
2867+
vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl);
2868+
vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
2869+
2870+
int sumi = __riscv_vmv_x_s_i32m1_i32(vs1);
2871+
sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
2872+
2873+
sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
2874+
}
2875+
2876+
*s = sumf;
28062877
#else
28072878
// scalar
28082879
float sumf = 0.0;
@@ -3037,6 +3108,76 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
30373108
}
30383109

30393110
*s = hsum_float_8(acc);
3111+
#elif defined(__riscv_v_intrinsic)
3112+
float sumf = 0.0;
3113+
3114+
uint32_t qh;
3115+
3116+
// These temp values are for masking and shift operations
3117+
uint32_t temp_1[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
3118+
uint32_t temp_2[16] = {0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80,
3119+
0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000};
3120+
3121+
size_t vl = __riscv_vsetvl_e8m1(qk/2);
3122+
3123+
for (int i = 0; i < nb; i++) {
3124+
memcpy(&qh, x[i].qh, sizeof(uint32_t));
3125+
3126+
// temporary registers
3127+
vuint32m4_t vt_1 = __riscv_vle32_v_u32m4(temp_2, vl);
3128+
vuint32m4_t vt_2 = __riscv_vle32_v_u32m4(temp_1, vl);
3129+
vuint32m4_t vt_3 = __riscv_vsll_vx_u32m4(vt_1, 16, vl);
3130+
vuint32m4_t vt_4 = __riscv_vadd_vx_u32m4(vt_2, 12, vl);
3131+
3132+
// ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
3133+
vuint32m4_t xha_0 = __riscv_vand_vx_u32m4(vt_1, qh, vl);
3134+
vuint32m4_t xhr_0 = __riscv_vsrl_vv_u32m4(xha_0, vt_2, vl);
3135+
vuint32m4_t xhl_0 = __riscv_vsll_vx_u32m4(xhr_0, 4, vl);
3136+
3137+
// ((qh & (1u << (j + 16))) >> (j + 12));
3138+
vuint32m4_t xha_1 = __riscv_vand_vx_u32m4(vt_3, qh, vl);
3139+
vuint32m4_t xhl_1 = __riscv_vsrl_vv_u32m4(xha_1, vt_4, vl);
3140+
3141+
// narrowing
3142+
vuint16m2_t xhc_0 = __riscv_vncvt_x_x_w_u16m2(xhl_0, vl);
3143+
vuint8m1_t xh_0 = __riscv_vncvt_x_x_w_u8m1(xhc_0, vl);
3144+
3145+
vuint16m2_t xhc_1 = __riscv_vncvt_x_x_w_u16m2(xhl_1, vl);
3146+
vuint8m1_t xh_1 = __riscv_vncvt_x_x_w_u8m1(xhc_1, vl);
3147+
3148+
// load
3149+
vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl);
3150+
3151+
vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
3152+
vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl);
3153+
3154+
vuint8m1_t x_at = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
3155+
vuint8m1_t x_lt = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
3156+
3157+
vuint8m1_t x_a = __riscv_vor_vv_u8m1(x_at, xh_0, vl);
3158+
vuint8m1_t x_l = __riscv_vor_vv_u8m1(x_lt, xh_1, vl);
3159+
3160+
vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
3161+
vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
3162+
3163+
vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 16, vl);
3164+
vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 16, vl);
3165+
3166+
vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
3167+
vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl);
3168+
3169+
vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
3170+
3171+
vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl);
3172+
vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
3173+
3174+
int sumi = __riscv_vmv_x_s_i32m1_i32(vs1);
3175+
sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
3176+
3177+
sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi;
3178+
}
3179+
3180+
*s = sumf;
30403181
#else
30413182
// scalar
30423183
float sumf = 0.0;
@@ -3293,6 +3434,72 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
32933434
}
32943435

32953436
*s = hsum_float_8(acc) + summs;
3437+
#elif defined(__riscv_v_intrinsic)
3438+
float sumf = 0.0;
3439+
3440+
uint32_t qh;
3441+
3442+
// These temp values are for shift operations
3443+
uint32_t temp_1[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
3444+
3445+
size_t vl = __riscv_vsetvl_e8m1(qk/2);
3446+
3447+
for (int i = 0; i < nb; i++) {
3448+
memcpy(&qh, x[i].qh, sizeof(uint32_t));
3449+
3450+
// temporary registers
3451+
vuint32m4_t vt_1 = __riscv_vle32_v_u32m4(temp_1, vl);
3452+
vuint32m4_t vt_2 = __riscv_vadd_vx_u32m4(vt_1, 12, vl);
3453+
3454+
// load qh
3455+
vuint32m4_t vqh = __riscv_vmv_v_x_u32m4(qh, vl);
3456+
3457+
// ((qh >> (j + 0)) << 4) & 0x10;
3458+
vuint32m4_t xhr_0 = __riscv_vsrl_vv_u32m4(vqh, vt_1, vl);
3459+
vuint32m4_t xhl_0 = __riscv_vsll_vx_u32m4(xhr_0, 4, vl);
3460+
vuint32m4_t xha_0 = __riscv_vand_vx_u32m4(xhl_0, 0x10, vl);
3461+
3462+
// ((qh >> (j + 12)) ) & 0x10;
3463+
vuint32m4_t xhr_1 = __riscv_vsrl_vv_u32m4(vqh, vt_2, vl);
3464+
vuint32m4_t xha_1 = __riscv_vand_vx_u32m4(xhr_1, 0x10, vl);
3465+
3466+
// narrowing
3467+
vuint16m2_t xhc_0 = __riscv_vncvt_x_x_w_u16m2(xha_0, vl);
3468+
vuint8m1_t xh_0 = __riscv_vncvt_x_x_w_u8m1(xhc_0, vl);
3469+
3470+
vuint16m2_t xhc_1 = __riscv_vncvt_x_x_w_u16m2(xha_1, vl);
3471+
vuint8m1_t xh_1 = __riscv_vncvt_x_x_w_u8m1(xhc_1, vl);
3472+
3473+
// load
3474+
vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl);
3475+
3476+
vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
3477+
vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl);
3478+
3479+
vuint8m1_t x_at = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
3480+
vuint8m1_t x_lt = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
3481+
3482+
vuint8m1_t x_a = __riscv_vor_vv_u8m1(x_at, xh_0, vl);
3483+
vuint8m1_t x_l = __riscv_vor_vv_u8m1(x_lt, xh_1, vl);
3484+
3485+
vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
3486+
vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
3487+
3488+
vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
3489+
vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl);
3490+
3491+
vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
3492+
3493+
vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl);
3494+
vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
3495+
3496+
int sumi = __riscv_vmv_x_s_i32m1_i32(vs1);
3497+
sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
3498+
3499+
sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
3500+
}
3501+
3502+
*s = sumf;
32963503
#else
32973504
// scalar
32983505
float sumf = 0.0;
@@ -3404,6 +3611,26 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
34043611
}
34053612

34063613
*s = hsum_float_8(acc);
3614+
#elif defined(__riscv_v_intrinsic)
3615+
float sumf = 0.0;
3616+
size_t vl = __riscv_vsetvl_e8m1(qk);
3617+
3618+
for (int i = 0; i < nb; i++) {
3619+
// load elements
3620+
vint8m1_t bx = __riscv_vle8_v_i8m1(x[i].qs, vl);
3621+
vint8m1_t by = __riscv_vle8_v_i8m1(y[i].qs, vl);
3622+
3623+
vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx, by, vl);
3624+
3625+
vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);
3626+
vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl);
3627+
3628+
int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum);
3629+
3630+
sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d));
3631+
}
3632+
3633+
*s = sumf;
34073634
#else
34083635
// scalar
34093636
float sumf = 0.0;

0 commit comments

Comments
 (0)