From 04be5b0ba4ee5dba621d9d7c8321921ee7037627 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 25 Mar 2023 18:40:13 +0200 Subject: [PATCH 1/2] Attempt to SIMD-ify dequantize_row_q4_0() for ARM_NEON --- ggml.c | 57 +++++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 55 insertions(+), 2 deletions(-) diff --git a/ggml.c b/ggml.c index 10334bd406854..3c2f221f042ba 100644 --- a/ggml.c +++ b/ggml.c @@ -755,7 +755,7 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) { const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs); const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + sizeof(float)); -#if defined(__AVX2__) && QK % 32 == 0 +#if defined(__AVX2__) for (int i = 0; i < nb; i++) { // scale factor const __m256 d_v = _mm256_broadcast_ss((const float *) (pd + i*bs)); @@ -788,7 +788,60 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) { } } } -//#elif defined(__ARM_NEON) +#elif defined(__ARM_NEON) + for (int i = 0; i < nb; i++) { + const float d = *(const float *) (pd + i*bs); + + const uint8_t * restrict pp = pb + i*bs; + + for (int l = 0; l < QK; l += 16) { + // Load 16x4-bit integers into 8x8-bit integers + const uint8x8_t v8 = vld1_u8(pp + l/2); + + // Expand 4-bit nibbles to 8-bit bytes + const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f)); + const uint8x8_t v1 = vshr_n_u8(v8, 4); + + /*printf("v0: %4d %4d %4d %4d %4d %4d %4d %4d\n", v0[0], v0[1], v0[2], v0[3], v0[4], v0[5], v0[6], v0[7]);*/ + + // Convert to signed 8-bit integers + const int8x8_t vs_0 = vreinterpret_s8_u8(v0); + const int8x8_t vs_1 = vreinterpret_s8_u8(v1); + + // Subtract 8 from each byte + const int8x8_t vb_0 = vsub_s8(vs_0, vdup_n_s8(8)); + const int8x8_t vb_1 = vsub_s8(vs_1, vdup_n_s8(8)); + + /*printf("vb_0: %4d %4d %4d %4d %4d %4d %4d %4d\n", vb_0[0], vb_0[1], vb_0[2], vb_0[3], vb_0[4], vb_0[5], vb_0[6], vb_0[7]);*/ + /*printf("vb_1: %4d %4d %4d %4d %4d %4d %4d %4d\n", vb_1[0], vb_1[1], vb_1[2], vb_1[3], vb_1[4], vb_1[5], vb_1[6], vb_1[7]);*/ + + // Convert to 16-bit integers + const int16x8_t vi_0 = vmovl_s8(vb_0); + const int16x8_t vi_1 = vmovl_s8(vb_1); + + /*printf("vi_0: %4d %4d %4d %4d %4d %4d %4d %4d\n", vi_0[0], vi_0[1], vi_0[2], vi_0[3], vi_0[4], vi_0[5], vi_0[6], vi_0[7]);*/ + + // Convert to 32-bit floats + const float32x4_t vf_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_0))); + const float32x4_t vf_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_0))); + const float32x4_t vf_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_1))); + const float32x4_t vf_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_1))); + + /*printf("vf: %4.1f %4.1f %4.1f %4.1f %4.1f %4.1f %4.1f %4.1f\n", vf_0[0], vf_0[1], vf_0[2], vf_0[3], vf_1[0], vf_1[1], vf_1[2], vf_1[3]);*/ + + // Multiply by d + const float32x4_t r0 = vmulq_n_f32(vf_0, d); + const float32x4_t r1 = vmulq_n_f32(vf_1, d); + const float32x4_t r2 = vmulq_n_f32(vf_2, d); + const float32x4_t r3 = vmulq_n_f32(vf_3, d); + + // Store + vst1q_f32(y + i*QK + l + 0, r0); + vst1q_f32(y + i*QK + l + 4, r1); + vst1q_f32(y + i*QK + l + 8, r2); + vst1q_f32(y + i*QK + l + 12, r3); + } + } #else // scalar for (int i = 0; i < nb; i++) { From b83ddbd7683bcda3ca4f7fedd3ce3b2969f59597 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 25 Mar 2023 19:31:23 +0200 Subject: [PATCH 2/2] Fix dequantization - forgot to interleave the quants --- ggml.c | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/ggml.c b/ggml.c index 3c2f221f042ba..291e12a0a2293 100644 --- a/ggml.c +++ b/ggml.c @@ -794,6 +794,8 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) { const uint8_t * restrict pp = pb + i*bs; + const float32x4_t vd = vdupq_n_f32(d); + for (int l = 0; l < QK; l += 16) { // Load 16x4-bit integers into 8x8-bit integers const uint8x8_t v8 = vld1_u8(pp + l/2); @@ -802,8 +804,6 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) { const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f)); const uint8x8_t v1 = vshr_n_u8(v8, 4); - /*printf("v0: %4d %4d %4d %4d %4d %4d %4d %4d\n", v0[0], v0[1], v0[2], v0[3], v0[4], v0[5], v0[6], v0[7]);*/ - // Convert to signed 8-bit integers const int8x8_t vs_0 = vreinterpret_s8_u8(v0); const int8x8_t vs_1 = vreinterpret_s8_u8(v1); @@ -812,28 +812,27 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) { const int8x8_t vb_0 = vsub_s8(vs_0, vdup_n_s8(8)); const int8x8_t vb_1 = vsub_s8(vs_1, vdup_n_s8(8)); - /*printf("vb_0: %4d %4d %4d %4d %4d %4d %4d %4d\n", vb_0[0], vb_0[1], vb_0[2], vb_0[3], vb_0[4], vb_0[5], vb_0[6], vb_0[7]);*/ - /*printf("vb_1: %4d %4d %4d %4d %4d %4d %4d %4d\n", vb_1[0], vb_1[1], vb_1[2], vb_1[3], vb_1[4], vb_1[5], vb_1[6], vb_1[7]);*/ + // Interleave and combine + const int8x8_t vx_0 = vzip1_s8(vb_0, vb_1); + const int8x8_t vx_1 = vzip2_s8(vb_0, vb_1); - // Convert to 16-bit integers - const int16x8_t vi_0 = vmovl_s8(vb_0); - const int16x8_t vi_1 = vmovl_s8(vb_1); + const int8x16_t vq = vcombine_s8(vx_0, vx_1); - /*printf("vi_0: %4d %4d %4d %4d %4d %4d %4d %4d\n", vi_0[0], vi_0[1], vi_0[2], vi_0[3], vi_0[4], vi_0[5], vi_0[6], vi_0[7]);*/ + // convert to 2x int16x8_t + const int16x8_t vi_0 = vmovl_s8(vget_low_s8 (vq)); + const int16x8_t vi_1 = vmovl_s8(vget_high_s8(vq)); - // Convert to 32-bit floats + // convert to 4x float32x4_t const float32x4_t vf_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_0))); const float32x4_t vf_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_0))); const float32x4_t vf_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_1))); const float32x4_t vf_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_1))); - /*printf("vf: %4.1f %4.1f %4.1f %4.1f %4.1f %4.1f %4.1f %4.1f\n", vf_0[0], vf_0[1], vf_0[2], vf_0[3], vf_1[0], vf_1[1], vf_1[2], vf_1[3]);*/ - // Multiply by d - const float32x4_t r0 = vmulq_n_f32(vf_0, d); - const float32x4_t r1 = vmulq_n_f32(vf_1, d); - const float32x4_t r2 = vmulq_n_f32(vf_2, d); - const float32x4_t r3 = vmulq_n_f32(vf_3, d); + const float32x4_t r0 = vmulq_f32(vf_0, vd); + const float32x4_t r1 = vmulq_f32(vf_1, vd); + const float32x4_t r2 = vmulq_f32(vf_2, vd); + const float32x4_t r3 = vmulq_f32(vf_3, vd); // Store vst1q_f32(y + i*QK + l + 0, r0);