Skip to content

Commit 6754255

Browse files
committed
ggml : poc for normalizing weights for better quantization
1 parent 1a94186 commit 6754255

File tree

4 files changed

+180
-81
lines changed

4 files changed

+180
-81
lines changed

ggml-cuda.cu

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,17 @@ typedef void (*ggml_cuda_op_t)(
7474
// QR = QK / number of values before dequantization
7575
// QI = number of 32 bit integers before dequantization
7676

77+
#define Q4_0DM (1.0f/8.0f)
78+
#define Q4_0D(x) (((x)*Q4_0DM) / 127.0f)
79+
7780
#define QK4_0 32
7881
#define QR4_0 2
7982
#define QI4_0 (QK4_0 / (4 * QR4_0))
8083
typedef struct {
81-
half d; // delta
84+
int8_t d; // delta
8285
uint8_t qs[QK4_0 / 2]; // nibbles / quants
8386
} block_q4_0;
84-
static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
87+
static_assert(sizeof(block_q4_0) == sizeof(int8_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
8588

8689
#define QK4_1 32
8790
#define QR4_1 2
@@ -103,16 +106,20 @@ typedef struct {
103106
} block_q5_0;
104107
static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
105108

109+
#define Q5_1DM (2.0f/31.0f)
110+
#define Q5_1MM (2.0f )
111+
#define Q5_1D(x) ( (((x) & 0x0F)*Q5_1DM) / 15.0f)
112+
#define Q5_1M(x) (-1.0f + (((x) >> 4)*Q5_1MM) / 15.0f)
113+
106114
#define QK5_1 32
107115
#define QR5_1 2
108116
#define QI5_1 (QK5_1 / (4 * QR5_1))
109117
typedef struct {
110-
half d; // delta
111-
half m; // min
118+
uint8_t dm; // 4-bit delta + 4-bit min
112119
uint8_t qh[4]; // 5-th bit of quants
113120
uint8_t qs[QK5_1 / 2]; // nibbles / quants
114121
} block_q5_1;
115-
static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
122+
static_assert(sizeof(block_q5_1) == sizeof(uint8_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
116123

117124
#define QK8_0 32
118125
#define QR8_0 1
@@ -360,7 +367,7 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
360367
static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
361368
const block_q4_0 * x = (const block_q4_0 *) vx;
362369

363-
const dfloat d = x[ib].d;
370+
const dfloat d = Q4_0D(x[ib].d);
364371

365372
const int vui = x[ib].qs[iqs];
366373

@@ -422,8 +429,8 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in
422429
static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
423430
const block_q5_1 * x = (const block_q5_1 *) vx;
424431

425-
const dfloat d = x[ib].d;
426-
const dfloat m = x[ib].m;
432+
const dfloat d = Q5_1D(x[ib].dm);
433+
const dfloat m = Q5_1M(x[ib].dm);
427434

428435
uint32_t qh;
429436
memcpy(&qh, x[ib].qh, sizeof(qh));
@@ -1336,7 +1343,7 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
13361343
const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
13371344
const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_0)]);
13381345

1339-
const float d = __half2float(bq4_0->d) * __half2float(bq8_1->d);
1346+
const float d = Q4_0D(bq4_0->d) * __half2float(bq8_1->d);
13401347

13411348
// subtract 8 from each quantized value
13421349
const int vi0 = __vsub4((vi >> 0) & 0x0F0F0F0F, 0x08080808);
@@ -1419,14 +1426,15 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
14191426
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
14201427
const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
14211428

1429+
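// TODO: fix misaligned access - with the 1-byte dm header, qs starts at byte offset 5 of a 21-byte block_q5_1, so the int load below is not 4-byte aligned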
14221430
const int qs = *((int *) &bq5_1->qs[sizeof(int) * (iqs + 0)]);
14231431
const int qh0 = bq5_1->qh[iqs/2 + 0] >> 4*(iqs%2);
14241432
const int qh1 = bq5_1->qh[iqs/2 + 2] >> 4*(iqs%2);
14251433
const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
14261434
const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI5_1)]);
14271435

1428-
const float d = __half2float(bq5_1->d) * __half2float(bq8_1->d);
1429-
const float m = bq5_1->m;
1436+
const float d = Q5_1D(bq5_1->dm) * __half2float(bq8_1->d);
1437+
const float m = Q5_1M(bq5_1->dm);
14301438
const float s = bq8_1->s;
14311439

14321440
int vi0 = (qs >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh0 as 5th bits

ggml.c

Lines changed: 60 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -892,12 +892,16 @@ inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
892892
#endif
893893
#endif
894894

895+
// the weights are assumed to be normalized to the [-1 .. 1] range, so with 4-bit quants (d = max / -8) abs(d) cannot exceed 1/8; d is stored as an int8 in steps of Q4_0DM/127
896+
#define Q4_0DM (1.0f/8.0f)
897+
#define Q4_0D(x) (((x)*Q4_0DM) / 127.0f)
898+
895899
#define QK4_0 32
896900
typedef struct {
897-
ggml_fp16_t d; // delta
901+
int8_t d; // delta
898902
uint8_t qs[QK4_0 / 2]; // nibbles / quants
899903
} block_q4_0;
900-
static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
904+
static_assert(sizeof(block_q4_0) == sizeof(int8_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
901905

902906
#define QK4_1 32
903907
typedef struct {
@@ -915,14 +919,21 @@ typedef struct {
915919
} block_q5_0;
916920
static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
917921

922+
// we know the values are in the [-1 .. 1] range, so:
923+
// - d is stored as an unsigned 4-bit value (low nibble of dm); its maximum, 15, decodes to 2.0/31, the largest possible delta when using 5-bit quants
924+
// - m is stored as an unsigned 4-bit value (high nibble of dm) that encodes the offset of the block minimum from -1.0; this offset cannot be more than 2.0
925+
#define Q5_1DM (2.0f/31.0f)
926+
#define Q5_1MM (2.0f )
927+
#define Q5_1D(x) ( (((x) & 0x0F)*Q5_1DM) / 15.0f)
928+
#define Q5_1M(x) (-1.0f + (((x) >> 4)*Q5_1MM) / 15.0f)
929+
918930
#define QK5_1 32
919931
typedef struct {
920-
ggml_fp16_t d; // delta
921-
ggml_fp16_t m; // min
932+
uint8_t dm; // 4-bit delta + 4-bit min
922933
uint8_t qh[4]; // 5-th bit of quants
923934
uint8_t qs[QK5_1 / 2]; // nibbles / quants
924935
} block_q5_1;
925-
static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
936+
static_assert(sizeof(block_q5_1) == sizeof(uint8_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
926937

927938
#define QK8_0 32
928939
typedef struct {
@@ -959,10 +970,13 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
959970
}
960971
}
961972

962-
const float d = max / -8;
963-
const float id = d ? 1.0f/d : 0.0f;
973+
float d = max / -8;
964974

965-
y[i].d = GGML_FP32_TO_FP16(d);
975+
y[i].d = (int8_t)(ceilf((127.0f * d) / Q4_0DM));
976+
977+
d = Q4_0D(y[i].d);
978+
979+
const float id = d ? 1.0f/d : 0.0f;
966980

967981
for (int j = 0; j < qk/2; ++j) {
968982
const float x0 = x[i*qk + 0 + j]*id;
@@ -1088,11 +1102,17 @@ static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * r
10881102
if (v > max) max = v;
10891103
}
10901104

1091-
const float d = (max - min) / ((1 << 5) - 1);
1092-
const float id = d ? 1.0f/d : 0.0f;
1105+
y[i].dm = (uint8_t)(floorf((15.0f * (min + 1.0f)) / Q5_1MM)) << 4;
10931106

1094-
y[i].d = GGML_FP32_TO_FP16(d);
1095-
y[i].m = GGML_FP32_TO_FP16(min);
1107+
min = Q5_1M(y[i].dm);
1108+
1109+
float d = (max - min) / ((1 << 5) - 1);
1110+
1111+
y[i].dm |= (uint8_t)(ceilf((15.0f * d) / Q5_1DM));
1112+
1113+
d = Q5_1D(y[i].dm);
1114+
1115+
const float id = d ? 1.0f/d : 0.0f;
10961116

10971117
uint32_t qh = 0;
10981118

@@ -1530,7 +1550,7 @@ static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict
15301550
const int nb = k / qk;
15311551

15321552
for (int i = 0; i < nb; i++) {
1533-
const float d = GGML_FP16_TO_FP32(x[i].d);
1553+
const float d = Q4_0D(x[i].d);
15341554

15351555
for (int j = 0; j < qk/2; ++j) {
15361556
const int x0 = (x[i].qs[j] & 0x0F) - 8;
@@ -1597,8 +1617,8 @@ static void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict
15971617
const int nb = k / qk;
15981618

15991619
for (int i = 0; i < nb; i++) {
1600-
const float d = GGML_FP16_TO_FP32(x[i].d);
1601-
const float m = GGML_FP16_TO_FP32(x[i].m);
1620+
const float d = Q5_1D(x[i].dm);
1621+
const float m = Q5_1M(x[i].dm);
16021622

16031623
uint32_t qh;
16041624
memcpy(&qh, x[i].qh, sizeof(qh));
@@ -2407,8 +2427,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
24072427
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
24082428
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
24092429

2410-
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
2411-
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
2430+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), Q4_0D(x0->d)*GGML_FP16_TO_FP32(y0->d));
2431+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), Q4_0D(x1->d)*GGML_FP16_TO_FP32(y1->d));
24122432
#else
24132433
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l));
24142434
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
@@ -2425,8 +2445,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
24252445
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
24262446
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
24272447

2428-
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
2429-
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
2448+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), Q4_0D(x0->d)*GGML_FP16_TO_FP32(y0->d));
2449+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), Q4_0D(x1->d)*GGML_FP16_TO_FP32(y1->d));
24302450
#endif
24312451
}
24322452

@@ -2438,7 +2458,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
24382458
// Main loop
24392459
for (int i = 0; i < nb; ++i) {
24402460
/* Compute combined scale for the block */
2441-
const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
2461+
const __m256 d = _mm256_set1_ps( Q4_0D(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
24422462

24432463
__m256i bx = bytes_from_nibbles_32(x[i].qs);
24442464

@@ -2462,7 +2482,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
24622482
// Main loop
24632483
for (int i = 0; i < nb; ++i) {
24642484
// Compute combined scale for the block
2465-
const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
2485+
const __m256 d = _mm256_set1_ps( Q4_0D(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
24662486

24672487
const __m128i lowMask = _mm_set1_epi8(0xF);
24682488
const __m128i off = _mm_set1_epi8(8);
@@ -2504,7 +2524,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
25042524
_mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0);
25052525

25062526
// Compute combined scale for the block 0 and 1
2507-
const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[0].d) * GGML_FP16_TO_FP32(y[0].d) );
2527+
const __m128 d_0_1 = _mm_set1_ps( Q4_0D(x[0].d) * GGML_FP16_TO_FP32(y[0].d) );
25082528

25092529
const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
25102530

@@ -2522,7 +2542,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
25222542
_mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0);
25232543

25242544
// Compute combined scale for the block 2 and 3
2525-
const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[1].d) * GGML_FP16_TO_FP32(y[1].d) );
2545+
const __m128 d_2_3 = _mm_set1_ps( Q4_0D(x[1].d) * GGML_FP16_TO_FP32(y[1].d) );
25262546

25272547
const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs);
25282548

@@ -2555,7 +2575,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
25552575
_mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0);
25562576

25572577
// Compute combined scale for the block 0 and 1
2558-
const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
2578+
const __m128 d_0_1 = _mm_set1_ps( Q4_0D(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
25592579

25602580
const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
25612581

@@ -2573,7 +2593,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
25732593
_mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
25742594

25752595
// Compute combined scale for the block 2 and 3
2576-
const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) );
2596+
const __m128 d_2_3 = _mm_set1_ps( Q4_0D(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) );
25772597

25782598
const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs);
25792599

@@ -2621,7 +2641,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
26212641
sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
26222642
}
26232643

2624-
sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
2644+
sumf += sumi*Q4_0D(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
26252645
}
26262646

26272647
*s = sumf;
@@ -3026,8 +3046,8 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
30263046

30273047
const uint8x16_t m4b = vdupq_n_u8(0x0F);
30283048

3029-
summs0 += GGML_FP16_TO_FP32(x0->m) * y0->s;
3030-
summs1 += GGML_FP16_TO_FP32(x1->m) * y1->s;
3049+
summs0 += Q5_1M(x0->dm) * y0->s;
3050+
summs1 += Q5_1M(x1->dm) * y1->s;
30313051

30323052
// extract the 5th bit via lookup table ((b) << 4)
30333053
memcpy(&qh0, x0->qh, sizeof(qh0));
@@ -3072,10 +3092,10 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
30723092
#if defined(__ARM_FEATURE_DOTPROD)
30733093
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
30743094
vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
3075-
vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d);
3095+
vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), Q5_1D(x0->dm)*y0->d);
30763096
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
30773097
vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
3078-
vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d);
3098+
vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), Q5_1D(x1->dm)*y1->d);
30793099
#else
30803100
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
30813101
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
@@ -3092,8 +3112,8 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
30923112
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
30933113
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
30943114

3095-
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
3096-
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
3115+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), Q5_1D(x0->dm)*y0->d);
3116+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), Q5_1D(x1->dm)*y1->d);
30973117
#endif
30983118
}
30993119

@@ -3111,7 +3131,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
31113131
const block_q5_1 * restrict x0 = &x[i];
31123132
const block_q8_1 * restrict y0 = &y[i];
31133133

3114-
summs += GGML_FP16_TO_FP32(x0->m) * y0->s;
3134+
summs += Q5_1M(x0->dm) * y0->s;
31153135

31163136
const v128_t m4b = wasm_i8x16_splat(0x0F);
31173137

@@ -3158,7 +3178,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
31583178
wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
31593179
wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
31603180
wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
3161-
wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * y0->d)));
3181+
wasm_f32x4_splat(Q5_1D(x0->dm) * y0->d)));
31623182
}
31633183

31643184
*s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
@@ -3171,9 +3191,9 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
31713191

31723192
// Main loop
31733193
for (int i = 0; i < nb; i++) {
3174-
const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
3194+
const __m256 dx = _mm256_set1_ps(Q5_1D(x[i].dm));
31753195

3176-
summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
3196+
summs += Q5_1M(x[i].dm) * y[i].s;
31773197

31783198
__m256i bx = bytes_from_nibbles_32(x[i].qs);
31793199
__m256i bxhi = bytes_from_bits_32(x[i].qh);
@@ -3198,9 +3218,9 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
31983218

31993219
// Main loop
32003220
for (int i = 0; i < nb; i++) {
3201-
const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
3221+
const __m256 dx = _mm256_set1_ps(Q5_1D(x[i].dm));
32023222

3203-
summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
3223+
summs += Q5_1M(x[i].dm) * y[i].s;
32043224

32053225
__m256i bx = bytes_from_nibbles_32(x[i].qs);
32063226
const __m256i bxhi = bytes_from_bits_32(x[i].qh);
@@ -3243,7 +3263,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
32433263
sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
32443264
}
32453265

3246-
sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
3266+
sumf += (Q5_1D(x[i].dm)*y[i].d)*sumi + Q5_1M(x[i].dm)*y[i].s;
32473267
}
32483268

32493269
*s = sumf;
@@ -5470,7 +5490,7 @@ struct ggml_tensor * ggml_sum_rows(
54705490
}
54715491

54725492
int64_t ne[4] = {1,1,1,1};
5473-
for (int i=1; i<a->n_dims; ++i) {
5493+
for (int i = 1; i < a->n_dims; ++i) {
54745494
ne[i] = a->ne[i];
54755495
}
54765496

ggml.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -281,9 +281,9 @@ extern "C" {
281281
GGML_TYPE_Q5_K = 13,
282282
GGML_TYPE_Q6_K = 14,
283283
GGML_TYPE_Q8_K = 15,
284-
GGML_TYPE_I8,
285-
GGML_TYPE_I16,
286-
GGML_TYPE_I32,
284+
GGML_TYPE_I8 = 16,
285+
GGML_TYPE_I16 = 17,
286+
GGML_TYPE_I32 = 18,
287287
GGML_TYPE_COUNT,
288288
};
289289

0 commit comments

Comments
 (0)