Skip to content

Fix Q4_K and Q5_K for QK_K = 64 on CUDA #2359

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 25, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 80 additions & 3 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1566,12 +1566,14 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
const block_q4_K * bq4_K = (const block_q4_K *) vbq;

// iqs is in 0...15. bq8_offset = 2 * (iqs/4) -> bq8_offset = 0, 2, 4, 6
const int bq8_offset = QR4_K * (iqs / (QI8_1/2));

float sumf_d = 0.0f;
float sumf_m = 0.0f;

#ifndef GGML_QKK_64

// iqs is in 0...15. bq8_offset = 2 * (iqs/4) -> bq8_offset = 0, 2, 4, 6
const int bq8_offset = QR4_K * (iqs / (QI8_1/2));

const float d = bq4_K->d;
const float dmin = bq4_K->dmin;

Expand Down Expand Up @@ -1616,6 +1618,43 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
}

return d*sumf_d - dmin*sumf_m;

#else

uint16_t aux16[2];
const uint8_t * s = (const uint8_t *)aux16;

const uint16_t * a = (const uint16_t *)bq4_K->scales;
aux16[0] = a[0] & 0x0f0f;
aux16[1] = (a[0] >> 4) & 0x0f0f;

const float dall = bq4_K->d[0];
const float dmin = bq4_K->d[1];

const float d8_1 = bq8_1[0].d;
const float d8_2 = bq8_1[1].d;

const int ui1 = *((const int *)bq8_1[0].qs + iqs);
const int ui2 = *((const int *)bq8_1[0].qs + iqs + 4);
const int ui3 = *((const int *)bq8_1[1].qs + iqs);
const int ui4 = *((const int *)bq8_1[1].qs + iqs + 4);

const int * q4 = (const int *)bq4_K->qs + iqs;
const int v1 = q4[0];
const int v2 = q4[4];

const int dot1 = __dp4a(ui2, v2 & 0x0f0f0f0f, __dp4a(ui1, v1 & 0x0f0f0f0f, 0));
const int dot2 = __dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, __dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0));
const int dot3 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0));
const int dot4 = __dp4a(0x01010101, ui4, __dp4a(0x01010101, ui3, 0));

sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]);
sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]);

return dall * sumf_d - dmin * sumf_m;

#endif

#else
return 0.0f; // only to satisfy the compiler
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
Expand All @@ -1627,6 +1666,8 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
const block_q5_K * bq5_K = (const block_q5_K *) vbq;

#ifndef GGML_QKK_64

const int bq8_offset = QR5_K * (iqs / (QI8_1/2));
const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * (iqs%4));
const int * qh = (const int *)(bq5_K->qh + 4 * (iqs%4));
Expand Down Expand Up @@ -1682,6 +1723,42 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
}

return d*sumf_d - dmin*sumf_m;

#else

const int8_t * s = bq5_K->scales;

const float d = bq5_K->d;

const float d8_1 = bq8_1[0].d;
const float d8_2 = bq8_1[1].d;

const int ui1 = *((const int *)bq8_1[0].qs + iqs);
const int ui2 = *((const int *)bq8_1[0].qs + iqs + 4);
const int ui3 = *((const int *)bq8_1[1].qs + iqs);
const int ui4 = *((const int *)bq8_1[1].qs + iqs + 4);

const int * ql = (const int *)bq5_K->qs + iqs;
const int vl1 = ql[0];
const int vl2 = ql[4];

const int step = 4 * iqs; // 0, 4, 8, 12
const int im = step/8; // = 0 for iqs = 0, 1, = 1 for iqs = 2, 3
const int in = step%8; // 0, 4, 0, 4
const int vh = (*((const int *)(bq5_K->qh + in))) >> im;

const int v1 = (((vh << 4) & 0x10101010) ^ 0x10101010) | ((vl1 >> 0) & 0x0f0f0f0f);
const int v2 = (((vh << 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 0) & 0x0f0f0f0f);
const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f);
const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f);

const float sumf_d = d8_1 * (__dp4a(ui1, v1, 0) * s[0] + __dp4a(ui2, v2, 0) * s[1])
+ d8_2 * (__dp4a(ui3, v3, 0) * s[2] + __dp4a(ui4, v4, 0) * s[3]);

return d * sumf_d;

#endif

#else
return 0.0f; // only to satisfy the compiler
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
Expand Down