@@ -903,13 +903,17 @@ typedef struct {
903
903
} block_q4_0;
904
904
static_assert(sizeof(block_q4_0) == sizeof(int8_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
905
905
906
+ #define Q4_1DM (2.0f/15.0f)
907
+ #define Q4_1MM (2.0f )
908
+ #define Q4_1D(x) ( (((x) & 0x0F)*Q4_1DM) / 15.0f)
909
+ #define Q4_1M(x) (-1.0f + (((x) >> 4)*Q4_1MM) / 15.0f)
910
+
906
911
#define QK4_1 32
907
912
typedef struct {
908
- ggml_fp16_t d; // delta
909
- ggml_fp16_t m; // min
913
+ uint8_t dm; // 4-bit delta + 4-bit min
910
914
uint8_t qs[QK4_1 / 2]; // nibbles / quants
911
915
} block_q4_1;
912
- static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_fp16_t ) + QK4_1 / 2, "wrong q4_1 block size/padding");
916
+ static_assert(sizeof(block_q4_1) == sizeof(uint8_t ) + QK4_1 / 2, "wrong q4_1 block size/padding");
913
917
914
918
#define QK5_0 32
915
919
typedef struct {
@@ -1013,11 +1017,17 @@ static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * r
1013
1017
if (v > max) max = v;
1014
1018
}
1015
1019
1016
- const float d = (max - min) / ((1 << 4) - 1);
1017
- const float id = d ? 1.0f/d : 0.0f;
1020
+ y[i].dm = (uint8_t)(floorf((15.0f * (min + 1.0f)) / Q4_1MM)) << 4;
1018
1021
1019
- y[i].d = GGML_FP32_TO_FP16(d);
1020
- y[i].m = GGML_FP32_TO_FP16(min);
1022
+ min = Q4_1M(y[i].dm);
1023
+
1024
+ float d = (max - min) / ((1 << 4) - 1);
1025
+
1026
+ y[i].dm |= (uint8_t)(ceilf((15.0f * d) / Q4_1DM));
1027
+
1028
+ d = Q4_1D(y[i].dm);
1029
+
1030
+ const float id = d ? 1.0f/d : 0.0f;
1021
1031
1022
1032
for (int j = 0; j < qk/2; ++j) {
1023
1033
const float x0 = (x[i*qk + 0 + j] - min)*id;
@@ -1570,8 +1580,8 @@ static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict
1570
1580
const int nb = k / qk;
1571
1581
1572
1582
for (int i = 0; i < nb; i++) {
1573
- const float d = GGML_FP16_TO_FP32 (x[i].d );
1574
- const float m = GGML_FP16_TO_FP32 (x[i].m );
1583
+ const float d = Q4_1D (x[i].dm );
1584
+ const float m = Q4_1M (x[i].dm );
1575
1585
1576
1586
for (int j = 0; j < qk/2; ++j) {
1577
1587
const int x0 = (x[i].qs[j] & 0x0F);
@@ -2671,7 +2681,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
2671
2681
const block_q8_1 * restrict y0 = &y[i + 0];
2672
2682
const block_q8_1 * restrict y1 = &y[i + 1];
2673
2683
2674
- summs += GGML_FP16_TO_FP32 (x0->m ) * y0->s + GGML_FP16_TO_FP32(x1->m) * y1->s;
2684
+ summs += Q4_1M (x0->dm ) * y0->s + GGML_FP16_TO_FP32(x1->m) * y1->s;
2675
2685
2676
2686
const uint8x16_t m4b = vdupq_n_u8(0x0F);
2677
2687
@@ -2695,8 +2705,8 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
2695
2705
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
2696
2706
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
2697
2707
2698
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32 (x0->d )*y0->d);
2699
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32 (x1->d )*y1->d);
2708
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), Q4_1D (x0->dm )*y0->d);
2709
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), Q4_1D (x1->dm )*y1->d);
2700
2710
#else
2701
2711
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l));
2702
2712
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l));
@@ -2713,8 +2723,8 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
2713
2723
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2714
2724
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2715
2725
2716
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32 (x0->d )*y0->d);
2717
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32 (x1->d )*y1->d);
2726
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), Q4_1D (x0->dm )*y0->d);
2727
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), Q4_1D (x1->dm )*y1->d);
2718
2728
#endif
2719
2729
}
2720
2730
@@ -2727,10 +2737,10 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
2727
2737
2728
2738
// Main loop
2729
2739
for (int i = 0; i < nb; ++i) {
2730
- const float d0 = GGML_FP16_TO_FP32 (x[i].d );
2740
+ const float d0 = Q4_1D (x[i].dm );
2731
2741
const float d1 = y[i].d;
2732
2742
2733
- summs += GGML_FP16_TO_FP32 (x[i].m ) * y[i].s;
2743
+ summs += Q4_1M (x[i].dm ) * y[i].s;
2734
2744
2735
2745
const __m256 d0v = _mm256_set1_ps( d0 );
2736
2746
const __m256 d1v = _mm256_set1_ps( d1 );
@@ -2767,7 +2777,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
2767
2777
sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
2768
2778
}
2769
2779
2770
- sumf += (GGML_FP16_TO_FP32 (x[i].d )*y[i].d)*sumi + GGML_FP16_TO_FP32 (x[i].m )*y[i].s;
2780
+ sumf += (Q4_1D (x[i].dm )*y[i].d)*sumi + Q4_1M (x[i].dm )*y[i].s;
2771
2781
}
2772
2782
2773
2783
*s = sumf;
0 commit comments