diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index fc4d2964ccac9..3758a825b7916 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -4581,6 +4581,26 @@ static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restr } +static const __dpct_inline__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; + +template +static void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_group(2); + const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);; + const int tid = item_ct1.get_local_id(2); + const int il = tid/8; // 0...3 + const int ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 4*il; + const uint8_t * q4 = x[ib].qs + 4*il; + const float d = (float)x[ib].d; + for (int j = 0; j < 4; ++j) { + y[j+0] = d * kvalues_iq4nl[q4[j] & 0xf]; + y[j+16] = d * kvalues_iq4nl[q4[j] >> 4]; + } + +} + /* DPCT1110:4: The total declared local variable size in device function dequantize_mul_mat_vec_q2_k exceeds 128 bytes and may cause high register @@ -7497,6 +7517,45 @@ vec_dot_iq1_s_q8_1(const void *__restrict__ vbq, #endif } + +static __dpct_inline__ void get_int_from_table_16(const uint32_t & q4, const uint8_t * values, + int & val1, int & val2) { + + uint32_t aux32; const uint8_t * q8 = (const uint8_t *)&aux32; + aux32 = q4 & 0x0f0f0f0f; + uint16_t v1 = values[q8[0]] | (values[q8[1]] << 8); + uint16_t v2 = values[q8[2]] | (values[q8[3]] << 8); + val1 = v1 | (v2 << 16); + aux32 = (q4 >> 4) & 0x0f0f0f0f; + v1 = values[q8[0]] | (values[q8[1]] << 8); + v2 = values[q8[2]] | (values[q8[3]] << 8); + val2 = v1 | (v2 << 16); +} + + +static __dpct_inline__ float +vec_dot_iq4_nl_q8_1(const void *__restrict__ vbq, + const block_q8_1 *__restrict__ bq8_1, const int &iqs) { + + const block_iq4_nl * bq = (const block_iq4_nl *) vbq; + const uint16_t * q4 = (const uint16_t *)bq->qs + 2*iqs; + const int32_t * q8 = (const int32_t *)bq8_1->qs + iqs; + const int ib32 = iqs; + const uint8_t * values = (const uint8_t *)kvalues_iq4nl; + + int v1, v2; + int sumi1 = 0, sumi2 = 0; + for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) { + const uint32_t aux = q4[2*l] | (q4[2*l+1] << 16); + get_int_from_table_16(aux, values, v1, v2); + sumi1 = dpct::dp4a(v1, q8[l+0], sumi1); + sumi2 = dpct::dp4a(v2, q8[l+4], sumi2); + } + + const float d = (float)bq->d * bq8_1[ib32].ds[0]; + return d * (sumi1 + sumi2); +} + template @@ -8353,6 +8412,52 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void * __restrict__ vx, const void * } } +template +static void mul_mat_vec_q_iq4_nl_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows, + const sycl::nd_item<3> &item_ct1) { + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + const int blocks_per_warp = vdr * WARP_SIZE / qi; + +// partial sum for each thread + float tmp = 0.0f; + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; + i += blocks_per_warp) { + const int ibx = row*blocks_per_row + i; // x block index + + const int iby = i * (qk/QK8_1); // y block index that aligns with ibx + + const int iqs = + vdr * + (item_ct1.get_local_id(2) % + (qi / vdr)); // x block quant index when casting the quants to int + + tmp += vec_dot_iq4_nl_q8_1(&x[ibx], &y[iby], iqs); + } + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (item_ct1.get_local_id(2) == 0) { + dst[row] = tmp; + } +} + + template static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows, const sycl::nd_item<3> &item_ct1) { @@ -10096,6 +10201,26 @@ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k, } } +template +static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int k, + dpct::queue_ptr stream) { + const int nb = (k + QK_K - 1) / QK_K; + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_iq4_nl( + vx, y, item_ct1); + }); + }); + } +} + template static void convert_unary_sycl(const void *__restrict__ vx, dst_t *__restrict__ y, const int k, @@ -10150,6 +10275,8 @@ static to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type) try { return dequantize_row_iq3_s_sycl; case GGML_TYPE_IQ1_S: return dequantize_row_iq1_s_sycl; + case GGML_TYPE_IQ4_NL: + return dequantize_row_iq4_nl_sycl; case GGML_TYPE_F32: return convert_unary_sycl; default: @@ -10194,6 +10321,8 @@ static to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type) { return dequantize_row_iq3_s_sycl; case GGML_TYPE_IQ1_S: return dequantize_row_iq1_s_sycl; + case GGML_TYPE_IQ4_NL: + return dequantize_row_iq4_nl_sycl; case GGML_TYPE_F16: return convert_unary_sycl; default: @@ -13612,6 +13741,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array= VER_GEN9 ? 128 : 64; case GGML_TYPE_IQ3_S: