Skip to content

Commit b2acede

Browse files
committed
cuda : add F32 -> Q4_0 and F32 -> Q4_1 copy kernels
1 parent e8457c9 commit b2acede

File tree

1 file changed

+37
-4
lines changed

1 file changed

+37
-4
lines changed

ggml-cuda.cu

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <stdio.h>
88
#include <atomic>
99
#include <assert.h>
10+
#include <float.h>
1011

1112
#if defined(GGML_USE_HIPBLAS)
1213
#include <hip/hip_runtime.h>
@@ -4587,20 +4588,20 @@ static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
45874588
block_q4_0 * dsti = (block_q4_0 *) cdsti;
45884589

45894590
float amax = 0.0f;
4590-
float max = 0.0f;
4591+
float vmax = 0.0f;
45914592

45924593
for (int j = 0; j < QK4_0; ++j) {
45934594
const float v = xi[j];
45944595
if (amax < fabsf(v)) {
45954596
amax = fabsf(v);
4596-
max = v;
4597+
vmax = v;
45974598
}
45984599
}
45994600

4600-
const float d = max / -8;
4601+
const float d = vmax / -8;
46014602
const float id = d ? 1.0f/d : 0.0f;
46024603

4603-
y[i].d = d;
4604+
dsti->d = d;
46044605

46054606
for (int j = 0; j < QK4_0/2; ++j) {
46064607
const float x0 = xi[0 + j]*id;
@@ -4614,6 +4615,38 @@ static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
46144615
}
46154616
}
46164617

4618+
static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
4619+
const float * xi = (const float *) cxi;
4620+
block_q4_1 * dsti = (block_q4_1 *) cdsti;
4621+
4622+
float vmin = FLT_MAX;
4623+
float vmax = -FLT_MAX;
4624+
4625+
for (int j = 0; j < QK4_1; ++j) {
4626+
const float v = xi[j];
4627+
4628+
if (v < vmin) vmin = v;
4629+
if (v > vmax) vmax = v;
4630+
}
4631+
4632+
const float d = (vmax - vmin) / ((1 << 4) - 1);
4633+
const float id = d ? 1.0f/d : 0.0f;
4634+
4635+
dsti->dm.x = d;
4636+
dsti->dm.y = vmin;
4637+
4638+
for (int j = 0; j < QK4_1/2; ++j) {
4639+
const float x0 = (xi[0 + j] - vmin)*id;
4640+
const float x1 = (xi[QK4_1/2 + j] - vmin)*id;
4641+
4642+
const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f));
4643+
const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f));
4644+
4645+
dsti->qs[j] = xi0;
4646+
dsti->qs[j] |= xi1 << 4;
4647+
}
4648+
}
4649+
46174650
template <cpy_kernel_t cpy_blck, int qk>
46184651
static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
46194652
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,

0 commit comments

Comments
 (0)