@@ -1404,6 +1404,9 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
1404
1404
case GGML_TYPE_Q4_0: {
1405
1405
ggml_compute_forward_get_rows_q4_0x8 (params, dst);
1406
1406
} break ;
1407
+ case GGML_TYPE_Q4_K: {
1408
+ ggml_compute_forward_get_rows_q4_Kx8 (params, dst);
1409
+ } break ;
1407
1410
default :
1408
1411
GGML_ABORT (" fatal error" );
1409
1412
break ;
@@ -1522,6 +1525,131 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
1522
1525
}
1523
1526
}
1524
1527
1528
    // GET_ROWS for src0 stored in the repacked q4_Kx8 layout (8 logical rows
    // interleaved per block_q4_Kx8): for each int32 index in src1, locate the
    // 8-row group holding that logical row in src0 and dequantize just that
    // one row into dst as floats.
    static void ggml_compute_forward_get_rows_q4_Kx8 (
        const ggml_compute_params * params,
        ggml_tensor * dst) {
        const ggml_tensor * src0 = dst->src [0 ];   // repacked quantized weights
        const ggml_tensor * src1 = dst->src [1 ];   // int32 row indices to gather

        GGML_TENSOR_BINARY_OP_LOCALS
        const int64_t nc = ne00;                    // elements per logical row
        const int64_t nr = ggml_nelements (src1);   // total number of rows to gather

        assert (ne0 == nc);
        assert (ne02 == ne11);
        assert (nb00 == ggml_type_size (src0->type ));
        assert (ggml_nrows (dst) == nr);

        const int ith = params->ith ;
        const int nth = params->nth ;

        // rows per thread
        const int dr = (nr + nth - 1 ) / nth;

        // row range for this thread
        const int ir0 = dr * ith;
        const int ir1 = MIN (ir0 + dr, nr);

        // q4_Kx8 interleaves 8 rows per repacked super-block
        constexpr int nrows_interleaved = 8 ;
        const size_t sizeof_one_repacked_block = sizeof (block_q4_Kx8);

        // number of block_q4_Kx8 super-blocks spanning one row's width
        // (assumes nc is a multiple of QK_K — dequantize_row_q4_Kx8 asserts this)
        const int num_repacked_blocks_per_row_width = nc / QK_K;

        // byte stride between consecutive 8-row groups in the repacked buffer
        const size_t stride_between_actual_row_groups = num_repacked_blocks_per_row_width * sizeof_one_repacked_block;

        for (int64_t i = ir0; i < ir1; ++i) {
            // decompose the flat gather index i into src1 coordinates (i12, i11, i10)
            const int64_t i12 = i / (ne11 * ne10);
            const int64_t i11 = (i - i12 * ne11 * ne10) / ne10;
            const int64_t i10 = (i - i12 * ne11 * ne10 - i11 * ne10);
            const int64_t i01 = *(int32_t *)((char *)src1->data + i10 * nb10 + i11 * nb11 + i12 * nb12); // original logical row

            GGML_ASSERT (i01 >= 0 && i01 < ne01);

            // which 8-row group the logical row lives in, and its slot (0-7) inside it
            int row_group_idx = i01 / nrows_interleaved;
            const int row_idx_in_group = i01 % nrows_interleaved;

            // base of the (i11, i12) slice of src0 in the higher dimensions
            const char * base_ptr_for_higher_dims_in_src0 = (const char *)src0->data + i11 * nb02 + i12 * nb03;

            // Pointer to the first block_q4_Kx8 of the identified row_group_idx
            const block_q4_Kx8 * p_first_repacked_block_of_group_x8 = (const block_q4_Kx8 *)(base_ptr_for_higher_dims_in_src0 + row_group_idx * stride_between_actual_row_groups);

            // write nc dequantized floats for this single row into dst
            dequantize_row_q4_Kx8 (
                p_first_repacked_block_of_group_x8,
                (float *)((char *)dst->data + i10 * nb1 + i11 * nb2 + i12 * nb3), nc, row_idx_in_group);
        }
    }
1581
+
1582
    /**
     * Dequantizes a single logical row from the repacked q4_Kx8 data format.
     *
     * @param p_repacked_blocks Pointer to the start of the 'block_q4_Kx8' structures for the entire row.
     * @param y Output buffer for the dequantized float values.
     * @param k Total number of elements (columns) in the logical row.
     * @param row_idx_in_group The index (0-7) of the logical row to extract from the interleaved data.
     */

    static void dequantize_row_q4_Kx8 (
        const void * GGML_RESTRICT p_repacked_blocks,
        float * GGML_RESTRICT y,
        int64_t k,
        int row_idx_in_group) {
        constexpr int nrows_interleaved = 8 ;
        assert (k % QK_K == 0 );
        assert (row_idx_in_group >= 0 && row_idx_in_group < nrows_interleaved);

        const int nb = k / QK_K;   // super-blocks per row
        const block_q4_Kx8 * blocks = (const block_q4_Kx8 *)p_repacked_blocks;

        for (int i = 0 ; i < nb; i++) {
            const block_q4_Kx8 * current_block = &blocks[i];

            // per-row super-block scale and min (fp16 -> fp32); d[] and dmin[]
            // hold one entry per interleaved row of the 8-row group
            const float d_super_block = GGML_FP16_TO_FP32 (current_block->d [row_idx_in_group]);
            const float dmin_super_block = GGML_FP16_TO_FP32 (current_block->dmin [row_idx_in_group]);

            const uint8_t * ptr_qs_base = current_block->qs ;
            const uint8_t * ptr_repacked_scales = (const uint8_t *)current_block->scales ;
            // is: 6-bit scale/min sub-block pair index; chunk_group_start_idx:
            // index of the next 64-byte interleaved quant chunk to consume
            int is = 0 , chunk_group_start_idx = 0 ;
            // each iteration emits 64 floats: 32 low-nibble values scaled by
            // sub-block 'is', then 32 high-nibble values scaled by 'is + 1'
            for (int j = 0 ; j < QK_K; j += 64 ) {

                uint8_t sc1, m1_val, sc2, m2_val;
                const uint8_t *scales_repacked_data;

                // NOTE(review): assumes the repack stores one 12-byte K-quant
                // scale group per sub-block index, each encoding all 8
                // interleaved rows — confirm against the repack routine
                scales_repacked_data = &ptr_repacked_scales[(is + 0 ) * 12 ];
                get_scale_min_k4 (row_idx_in_group, scales_repacked_data, &sc1, &m1_val);

                scales_repacked_data = &ptr_repacked_scales[(is + 1 ) * 12 ];
                get_scale_min_k4 (row_idx_in_group, scales_repacked_data, &sc2, &m2_val);

                // effective scale/offset for each of the two 32-element sub-blocks
                const float d1 = d_super_block * sc1;
                const float m1 = dmin_super_block * m1_val;
                const float d2 = d_super_block * sc2;
                const float m2 = dmin_super_block * m2_val;

                // qs is interleaved in 64-byte chunks (8 bytes per row x 8 rows);
                // low nibbles give the first 32 values of this 64-value span
                for (int idx = 0 ; idx < 4 ; idx++) {
                    const uint8_t * ptr_qs_chunk = ptr_qs_base + ((chunk_group_start_idx + idx) * 64 ) + row_idx_in_group * 8 ;
                    for (int l = 0 ; l < 8 ; ++l) *y++ = d1 * (ptr_qs_chunk[l] & 0xF ) - m1; // 8 quants from this chunk
                }

                // high nibbles of the same bytes give the next 32 values
                for (int idx = 0 ; idx < 4 ; idx++) {
                    const uint8_t * ptr_qs_chunk = ptr_qs_base + ((chunk_group_start_idx + idx) * 64 ) + row_idx_in_group * 8 ;
                    for (int l = 0 ; l < 8 ; ++l) *y++ = d2 * (ptr_qs_chunk[l] >> 4 ) - m2; // 8 quants from this chunk
                }
                is += 2 ;
                chunk_group_start_idx += 4 ;
            }
        }
    }
1642
+
1643
+ static inline void get_scale_min_k4 (int j, const uint8_t *GGML_RESTRICT s, uint8_t *GGML_RESTRICT d, uint8_t *GGML_RESTRICT m) {
1644
+ if (j < 4 ) {
1645
+ *d = s[j] & 63 ;
1646
+ *m = s[j + 4 ] & 63 ;
1647
+ } else {
1648
+ *d = (s[j + 4 ] & 0xF ) | ((s[j - 4 ] >> 6 ) << 4 );
1649
+ *m = (s[j + 4 ] >> 4 ) | ((s[j - 0 ] >> 6 ) << 4 );
1650
+ }
1651
+ }
1652
+
1525
1653
int repack (struct ggml_tensor * t, const void * data, size_t data_size) override {
1526
1654
GGML_LOG_DEBUG (" %s: repack tensor %s with %s_%dx%d\n " , __func__, t->name , ggml_type_name (t->type ),
1527
1655
(int ) NB_COLS, (int ) INTER_SIZE);
@@ -1662,7 +1790,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
1662
1790
if (op->src [1 ]->buffer && !ggml_backend_buft_is_host (op->src [1 ]->buffer ->buft )) {
1663
1791
return false ;
1664
1792
}
1665
- if (op->src [0 ]->type == GGML_TYPE_Q4_0) {
1793
+ if (op->src [0 ]->type == GGML_TYPE_Q4_0 || op-> src [ 0 ]-> type == GGML_TYPE_Q4_K ) {
1666
1794
return true ;
1667
1795
}
1668
1796
}
0 commit comments