7
7
#include <stdio.h>
8
8
#include <atomic>
9
9
#include <assert.h>
10
+ #include <float.h>
10
11
11
12
#if defined(GGML_USE_HIPBLAS)
12
13
#include <hip/hip_runtime.h>
@@ -4587,20 +4588,20 @@ static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
4587
4588
block_q4_0 * dsti = (block_q4_0 *) cdsti;
4588
4589
4589
4590
float amax = 0.0f;
4590
- float max = 0.0f;
4591
+ float vmax = 0.0f;
4591
4592
4592
4593
for (int j = 0; j < QK4_0; ++j) {
4593
4594
const float v = xi[j];
4594
4595
if (amax < fabsf(v)) {
4595
4596
amax = fabsf(v);
4596
- max = v;
4597
+ vmax = v;
4597
4598
}
4598
4599
}
4599
4600
4600
- const float d = max / -8;
4601
+ const float d = vmax / -8;
4601
4602
const float id = d ? 1.0f/d : 0.0f;
4602
4603
4603
- y[i]. d = d;
4604
+ dsti-> d = d;
4604
4605
4605
4606
for (int j = 0; j < QK4_0/2; ++j) {
4606
4607
const float x0 = xi[0 + j]*id;
@@ -4614,6 +4615,38 @@ static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
4614
4615
}
4615
4616
}
4616
4617
4618
// Quantizes one block of QK4_1 consecutive floats into a single block_q4_1.
//
//   cxi   - points at the QK4_1 source floats (passed as raw bytes)
//   cdsti - points at the destination block_q4_1 (passed as raw bytes)
//
// The block stores a step size in dm.x and the block minimum in dm.y; every
// value is encoded as a 4-bit code q such that value ~= q*dm.x + dm.y.
// Byte k of qs holds the code of element k in its low nibble and the code of
// element k + QK4_1/2 in its high nibble.
static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
    const float * src = (const float *) cxi;
    block_q4_1  * dst = (block_q4_1 *) cdsti;

    // Pass 1: find the smallest and largest value in the block.
    float lo = FLT_MAX;
    float hi = -FLT_MAX;

    for (int k = 0; k < QK4_1; ++k) {
        const float v = src[k];

        lo = (v < lo) ? v : lo;
        hi = (v > hi) ? v : hi;
    }

    // The [lo, hi] range is divided into (1 << 4) - 1 = 15 steps, one per
    // non-zero 4-bit code. A zero step (all values equal) gets a zero inverse,
    // so every code comes out as 0 instead of dividing by zero.
    const float step     = (hi - lo) / ((1 << 4) - 1);
    const float inv_step = step ? 1.0f/step : 0.0f;

    dst->dm.x = step;
    dst->dm.y = lo;

    // Pass 2: encode the first and second half of the block in lockstep,
    // packing one code from each half into every output byte.
    const int n_pairs = QK4_1/2;

    for (int k = 0; k < n_pairs; ++k) {
        const float scaled_lo = (src[k]           - lo)*inv_step;
        const float scaled_hi = (src[n_pairs + k] - lo)*inv_step;

        // Adding 0.5 before the truncating cast rounds to nearest; the result
        // is then capped at 15, the largest 4-bit code.
        const uint8_t nib_lo = min(15, (int8_t)(scaled_lo + 0.5f));
        const uint8_t nib_hi = min(15, (int8_t)(scaled_hi + 0.5f));

        dst->qs[k] = nib_lo | (nib_hi << 4);
    }
}
4649
+
4617
4650
template <cpy_kernel_t cpy_blck, int qk>
4618
4651
static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
4619
4652
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
0 commit comments