Skip to content

[SYCL] Iq4 nl #6363

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 130 additions & 0 deletions ggml-sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4581,6 +4581,26 @@ static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restr

}

static const __dpct_inline__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};

template<typename dst_t>
static void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_group(2);
const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);;
const int tid = item_ct1.get_local_id(2);
const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 4*il;
const uint8_t * q4 = x[ib].qs + 4*il;
const float d = (float)x[ib].d;
for (int j = 0; j < 4; ++j) {
y[j+0] = d * kvalues_iq4nl[q4[j] & 0xf];
y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
}

}

/*
DPCT1110:4: The total declared local variable size in device function
dequantize_mul_mat_vec_q2_k exceeds 128 bytes and may cause high register
Expand Down Expand Up @@ -7497,6 +7517,45 @@ vec_dot_iq1_s_q8_1(const void *__restrict__ vbq,
#endif
}


static __dpct_inline__ void get_int_from_table_16(const uint32_t & q4, const uint8_t * values,
int & val1, int & val2) {

uint32_t aux32; const uint8_t * q8 = (const uint8_t *)&aux32;
aux32 = q4 & 0x0f0f0f0f;
uint16_t v1 = values[q8[0]] | (values[q8[1]] << 8);
uint16_t v2 = values[q8[2]] | (values[q8[3]] << 8);
val1 = v1 | (v2 << 16);
aux32 = (q4 >> 4) & 0x0f0f0f0f;
v1 = values[q8[0]] | (values[q8[1]] << 8);
v2 = values[q8[2]] | (values[q8[3]] << 8);
val2 = v1 | (v2 << 16);
}


static __dpct_inline__ float
vec_dot_iq4_nl_q8_1(const void *__restrict__ vbq,
const block_q8_1 *__restrict__ bq8_1, const int &iqs) {

const block_iq4_nl * bq = (const block_iq4_nl *) vbq;
const uint16_t * q4 = (const uint16_t *)bq->qs + 2*iqs;
const int32_t * q8 = (const int32_t *)bq8_1->qs + iqs;
const int ib32 = iqs;
const uint8_t * values = (const uint8_t *)kvalues_iq4nl;

int v1, v2;
int sumi1 = 0, sumi2 = 0;
for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
const uint32_t aux = q4[2*l] | (q4[2*l+1] << 16);
get_int_from_table_16(aux, values, v1, v2);
sumi1 = dpct::dp4a(v1, q8[l+0], sumi1);
sumi2 = dpct::dp4a(v2, q8[l+4], sumi2);
}

const float d = (float)bq->d * bq8_1[ib32].ds[0];
return d * (sumi1 + sumi2);
}

template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x,
int mmq_y, int nwarps, load_tiles_sycl_t load_tiles, int vdr,
vec_dot_q_mul_mat_sycl_t vec_dot>
Expand Down Expand Up @@ -8353,6 +8412,52 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void * __restrict__ vx, const void *
}
}

template <int qk, int qi, typename block_q_t, int vdr>
static void mul_mat_vec_q_iq4_nl_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
const sycl::nd_item<3> &item_ct1) {
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
item_ct1.get_local_id(1);

if (row >= nrows) {
return;
}

const int blocks_per_row = ncols / qk;
const int blocks_per_warp = vdr * WARP_SIZE / qi;

// partial sum for each thread
float tmp = 0.0f;

const block_q_t * x = (const block_q_t *) vx;
const block_q8_1 * y = (const block_q8_1 *) vy;

for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
i += blocks_per_warp) {
const int ibx = row*blocks_per_row + i; // x block index

const int iby = i * (qk/QK8_1); // y block index that aligns with ibx

const int iqs =
vdr *
(item_ct1.get_local_id(2) %
(qi / vdr)); // x block quant index when casting the quants to int

tmp += vec_dot_iq4_nl_q8_1(&x[ibx], &y[iby], iqs);
}

// sum up partial sums and write back result
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
}

if (item_ct1.get_local_id(2) == 0) {
dst[row] = tmp;
}
}


template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows,
const sycl::nd_item<3> &item_ct1) {
Expand Down Expand Up @@ -10096,6 +10201,26 @@ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
}
}

template <typename dst_t>
static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int k,
dpct::queue_ptr stream) {
const int nb = (k + QK_K - 1) / QK_K;
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});

stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq4_nl(
vx, y, item_ct1);
});
});
}
}

template <typename src_t, typename dst_t>
static void convert_unary_sycl(const void *__restrict__ vx,
dst_t *__restrict__ y, const int k,
Expand Down Expand Up @@ -10150,6 +10275,8 @@ static to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type) try {
return dequantize_row_iq3_s_sycl;
case GGML_TYPE_IQ1_S:
return dequantize_row_iq1_s_sycl;
case GGML_TYPE_IQ4_NL:
return dequantize_row_iq4_nl_sycl;
case GGML_TYPE_F32:
return convert_unary_sycl<float>;
default:
Expand Down Expand Up @@ -10194,6 +10321,8 @@ static to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type) {
return dequantize_row_iq3_s_sycl;
case GGML_TYPE_IQ1_S:
return dequantize_row_iq1_s_sycl;
case GGML_TYPE_IQ4_NL:
return dequantize_row_iq4_nl_sycl;
case GGML_TYPE_F16:
return convert_unary_sycl<sycl::half>;
default:
Expand Down Expand Up @@ -13612,6 +13741,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_SYC
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ3_XXS:
return max_compute_capability >= VER_GEN9 ? 128 : 64;
case GGML_TYPE_IQ3_S:
Expand Down