
Commit 70cf05a

Add the Get_Rows & Dequantize implementation adapted to work for repacked weights of type q4_K

1 parent 2705c08

File tree: 2 files changed (+130, −2 lines)

ggml/src/ggml-cpu/repack.cpp (129 additions, 1 deletion)
@@ -1404,6 +1404,9 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
             case GGML_TYPE_Q4_0: {
                 ggml_compute_forward_get_rows_q4_0x8(params, dst);
             } break;
+            case GGML_TYPE_Q4_K: {
+                ggml_compute_forward_get_rows_q4_Kx8(params, dst);
+            } break;
             default:
                 GGML_ABORT("fatal error");
                 break;
@@ -1522,6 +1525,131 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
         }
     }
 
+    static void ggml_compute_forward_get_rows_q4_Kx8(
+            const ggml_compute_params * params,
+            ggml_tensor * dst) {
+        const ggml_tensor * src0 = dst->src[0];
+        const ggml_tensor * src1 = dst->src[1];
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        const int64_t nc = ne00;
+        const int64_t nr = ggml_nelements(src1);
+
+        assert(ne0 == nc);
+        assert(ne02 == ne11);
+        assert(nb00 == ggml_type_size(src0->type));
+        assert(ggml_nrows(dst) == nr);
+
+        const int ith = params->ith;
+        const int nth = params->nth;
+
+        // rows per thread
+        const int dr = (nr + nth - 1) / nth;
+
+        // row range for this thread
+        const int ir0 = dr * ith;
+        const int ir1 = MIN(ir0 + dr, nr);
+
+        constexpr int nrows_interleaved = 8;
+        const size_t sizeof_one_repacked_block = sizeof(block_q4_Kx8);
+
+        const int num_repacked_blocks_per_row_width = nc / QK_K;
+
+        const size_t stride_between_actual_row_groups = num_repacked_blocks_per_row_width * sizeof_one_repacked_block;
+
+        for (int64_t i = ir0; i < ir1; ++i) {
+            const int64_t i12 = i / (ne11 * ne10);
+            const int64_t i11 = (i - i12 * ne11 * ne10) / ne10;
+            const int64_t i10 = (i - i12 * ne11 * ne10 - i11 * ne10);
+            const int64_t i01 = *(int32_t *)((char *)src1->data + i10 * nb10 + i11 * nb11 + i12 * nb12); // original logical row
+
+            GGML_ASSERT(i01 >= 0 && i01 < ne01);
+
+            int row_group_idx = i01 / nrows_interleaved;
+            const int row_idx_in_group = i01 % nrows_interleaved;
+
+            const char * base_ptr_for_higher_dims_in_src0 = (const char *)src0->data + i11 * nb02 + i12 * nb03;
+
+            // Pointer to the first block_q4_Kx8 of the identified row_group_idx
+            const block_q4_Kx8 * p_first_repacked_block_of_group_x8 = (const block_q4_Kx8 *)(base_ptr_for_higher_dims_in_src0 + row_group_idx * stride_between_actual_row_groups);
+
+            dequantize_row_q4_Kx8(
+                p_first_repacked_block_of_group_x8,
+                (float *)((char *)dst->data + i10 * nb1 + i11 * nb2 + i12 * nb3), nc, row_idx_in_group);
+        }
+    }
+
+    /**
+     * Dequantizes a single logical row from the repacked q4_Kx8 data format.
+     *
+     * @param p_repacked_blocks Pointer to the start of the block_q4_Kx8 structures for the entire row.
+     * @param y                 Output buffer for the dequantized float values.
+     * @param k                 Total number of elements (columns) in the logical row.
+     * @param row_idx_in_group  The index (0-7) of the logical row to extract from the interleaved data.
+     */
+    static void dequantize_row_q4_Kx8(
+            const void * GGML_RESTRICT p_repacked_blocks,
+            float * GGML_RESTRICT y,
+            int64_t k,
+            int row_idx_in_group) {
+        constexpr int nrows_interleaved = 8;
+        assert(k % QK_K == 0);
+        assert(row_idx_in_group >= 0 && row_idx_in_group < nrows_interleaved);
+
+        const int nb = k / QK_K;
+        const block_q4_Kx8 * blocks = (const block_q4_Kx8 *)p_repacked_blocks;
+
+        for (int i = 0; i < nb; i++) {
+            const block_q4_Kx8 * current_block = &blocks[i];
+
+            const float d_super_block    = GGML_FP16_TO_FP32(current_block->d[row_idx_in_group]);
+            const float dmin_super_block = GGML_FP16_TO_FP32(current_block->dmin[row_idx_in_group]);
+
+            const uint8_t * ptr_qs_base         = current_block->qs;
+            const uint8_t * ptr_repacked_scales = (const uint8_t *)current_block->scales;
+            int is = 0, chunk_group_start_idx = 0;
+
+            for (int j = 0; j < QK_K; j += 64) {
+                uint8_t sc1, m1_val, sc2, m2_val;
+                const uint8_t * scales_repacked_data;
+
+                scales_repacked_data = &ptr_repacked_scales[(is + 0) * 12];
+                get_scale_min_k4(row_idx_in_group, scales_repacked_data, &sc1, &m1_val);
+
+                scales_repacked_data = &ptr_repacked_scales[(is + 1) * 12];
+                get_scale_min_k4(row_idx_in_group, scales_repacked_data, &sc2, &m2_val);
+
+                const float d1 = d_super_block * sc1;
+                const float m1 = dmin_super_block * m1_val;
+                const float d2 = d_super_block * sc2;
+                const float m2 = dmin_super_block * m2_val;
+
+                // low nibbles: 32 elements (4 chunks x 8 quants per row)
+                for (int idx = 0; idx < 4; idx++) {
+                    const uint8_t * ptr_qs_chunk = ptr_qs_base + ((chunk_group_start_idx + idx) * 64) + row_idx_in_group * 8;
+                    for (int l = 0; l < 8; ++l) *y++ = d1 * (ptr_qs_chunk[l] & 0xF) - m1;
+                }
+
+                // high nibbles: the same 32 bytes yield the next 32 elements
+                for (int idx = 0; idx < 4; idx++) {
+                    const uint8_t * ptr_qs_chunk = ptr_qs_base + ((chunk_group_start_idx + idx) * 64) + row_idx_in_group * 8;
+                    for (int l = 0; l < 8; ++l) *y++ = d2 * (ptr_qs_chunk[l] >> 4) - m2;
+                }
+
+                is += 2;
+                chunk_group_start_idx += 4;
+            }
+        }
+    }
+
+    static inline void get_scale_min_k4(int j, const uint8_t * GGML_RESTRICT s, uint8_t * GGML_RESTRICT d, uint8_t * GGML_RESTRICT m) {
+        if (j < 4) {
+            *d = s[j] & 63;
+            *m = s[j + 4] & 63;
+        } else {
+            *d = (s[j + 4] & 0xF) | ((s[j - 4] >> 6) << 4);
+            *m = (s[j + 4] >> 4) | ((s[j - 0] >> 6) << 4);
+        }
+    }
+
     int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
         GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type),
                        (int) NB_COLS, (int) INTER_SIZE);
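The 6-bit unpacking in get_scale_min_k4 above follows the standard q4_K scales packing: twelve bytes carry eight scale/min pairs, where entries 4..7 are split across a low nibble and two spare high bits borrowed from entries 0..3. A small self-contained harness working one hypothetical example (the values and the harness are mine, not from the commit; the unpacking rule mirrors the function above):

#include <cassert>
#include <cstdint>
#include <cstdio>

// Same rule as get_scale_min_k4 in the diff above.
static void unpack_scale_min(int j, const uint8_t * s, uint8_t * d, uint8_t * m) {
    if (j < 4) {
        *d = s[j] & 63;
        *m = s[j + 4] & 63;
    } else {
        *d = (s[j + 4] & 0xF) | ((s[j - 4] >> 6) << 4);
        *m = (s[j + 4] >> 4) | ((s[j] >> 6) << 4);
    }
}

int main() {
    // Hypothetical 12-byte table: entry j=5 takes its low nibbles from s[9]
    // and its top two bits from s[1] (scale) and s[5] (min).
    uint8_t s[12] = {0};
    s[9] = 0xA7; // low nibble 0x7 -> scale bits 0..3; high nibble 0xA -> min bits 0..3
    s[1] = 0xC0; // top two bits (0x3) -> scale bits 4..5
    s[5] = 0x80; // top two bits (0x2) -> min bits 4..5
    uint8_t sc, mn;
    unpack_scale_min(5, s, &sc, &mn);
    assert(sc == 0x37); // 0x7 | (0x3 << 4) = 55
    assert(mn == 0x2A); // 0xA | (0x2 << 4) = 42
    printf("scale=%u min=%u\n", sc, mn);
    return 0;
}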
@@ -1662,7 +1790,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
         if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
             return false;
         }
-        if (op->src[0]->type == GGML_TYPE_Q4_0) {
+        if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_K) {
             return true;
         }
     }
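The row addressing in ggml_compute_forward_get_rows_q4_Kx8 boils down to a two-step lookup: find the 8-row bundle, then the lane within it. Restated as a hedged standalone helper (the function name is mine; block_q4_Kx8 and QK_K as in the sketch near the top of this page):

#include <cstddef>
#include <cstdint>

// Locate the bundle of 8 interleaved rows containing logical row i01 of a
// matrix nc columns wide (nc % QK_K == 0), and report which lane the row is.
static const block_q4_Kx8 * row_bundle(const void * src0_data, int64_t i01, int64_t nc, int * lane) {
    constexpr int nrows_interleaved = 8;
    const size_t bundle_stride = (size_t)(nc / QK_K) * sizeof(block_q4_Kx8); // bytes per 8-row bundle
    *lane = (int)(i01 % nrows_interleaved);                                  // row_idx_in_group in the diff
    return (const block_q4_Kx8 *)((const char *)src0_data
                                  + (size_t)(i01 / nrows_interleaved) * bundle_stride);
}

dequantize_row_q4_Kx8 then walks nc / QK_K of these bundles, reading only the lane's 8-byte slice of every 64-byte qs chunk.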

src/whisper.cpp (1 addition, 1 deletion)
@@ -1454,7 +1454,7 @@ static bool weight_buft_supported(const whisper_hparams & hparams, ggml_tensor *
     int64_t n_ctx = hparams.n_audio_ctx;
 
     switch (op) {
-        // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT & GGML_OP_GET_ROWS (q4_0)
+        // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT & GGML_OP_GET_ROWS (repacked - q4_0, q4_K)
         case GGML_OP_MUL_MAT: {
            ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
            op_tensor = ggml_mul_mat(ctx, w, b);
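The updated comment refers to the probe pattern in weight_buft_supported: build a representative op over the weight tensor and ask the buffer type whether it can run it. A hedged sketch of what the GGML_OP_GET_ROWS case amounts to (the helper name and row count are illustrative; ggml_get_rows and ggml_new_tensor_1d are real ggml API):

#include "ggml.h"

// Illustrative GET_ROWS probe: gather rows of w with I32 indices; for a
// quantized w the result is dequantized to F32, which is exactly the path
// the new repack.cpp code now serves for Q4_K.
static ggml_tensor * get_rows_probe(ggml_context * ctx, ggml_tensor * w, int64_t n_rows) {
    ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_rows);
    return ggml_get_rows(ctx, w, ids);
}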
