@@ -607,10 +607,11 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
607
607
assert (k % QK == 0 );
608
608
609
609
const int nb = k / QK ;
610
+ const size_t bs = 2 * sizeof (float ) + QK /2 ;
610
611
611
- float * restrict pm = (float * ) ( y );
612
- float * restrict pd = (float * ) ( pm + nb );
613
- uint8_t * restrict pb = (uint8_t * ) ( pd + nb );
612
+ uint8_t * restrict pd = (( uint8_t * )y + 0 * bs );
613
+ uint8_t * restrict pm = (( uint8_t * )y + 0 * bs + sizeof ( float ) );
614
+ uint8_t * restrict pb = (( uint8_t * )y + 0 * bs + 2 * sizeof ( float ) );
614
615
615
616
uint8_t pp [QK /2 ];
616
617
@@ -627,8 +628,10 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
627
628
const float d = (max - min ) / ((1 << 4 ) - 1 );
628
629
const float id = d ? 1.0f /d : 0.0f ;
629
630
630
- pm [i ] = min ;
631
- pd [i ] = d ;
631
+ * (float * )pm = min ;
632
+ * (float * )pd = d ;
633
+ pm += bs ;
634
+ pd += bs ;
632
635
633
636
for (int l = 0 ; l < QK ; l += 2 ) {
634
637
const float v0 = (x [i * QK + l + 0 ] - min )* id ;
@@ -643,7 +646,8 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
643
646
pp [l /2 ] = vi0 | (vi1 << 4 );
644
647
}
645
648
646
- memcpy (pb + i * QK /2 , pp , sizeof (pp ));
649
+ memcpy (pb , pp , sizeof (pp ));
650
+ pb += bs ;
647
651
}
648
652
}
649
653
@@ -687,16 +691,17 @@ void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {
687
691
assert (k % QK == 0 );
688
692
689
693
const int nb = k / QK ;
694
+ const size_t bs = 2 * sizeof (float ) + QK /2 ;
690
695
691
- const float * restrict pm = (const float * ) ( x );
692
- const float * restrict pd = (const float * ) ( pm + nb );
693
- const uint8_t * restrict pb = (const uint8_t * ) ( pd + nb );
696
+ const uint8_t * restrict pd = (( const uint8_t * )x + 0 * bs );
697
+ const uint8_t * restrict pm = (( const uint8_t * )x + 0 * bs + sizeof ( float ) );
698
+ const uint8_t * restrict pb = (( const uint8_t * )x + 0 * bs + 2 * sizeof ( float ) );
694
699
695
700
for (int i = 0 ; i < nb ; i ++ ) {
696
- const float m = pm [ i ] ;
697
- const float d = pd [ i ] ;
701
+ const float d = * ( const float * ) ( pd + i * bs ) ;
702
+ const float m = * ( const float * ) ( pm + i * bs ) ;
698
703
699
- const uint8_t * restrict pp = pb + i * QK / 2 ;
704
+ const uint8_t * restrict pp = pb + i * bs ;
700
705
701
706
for (int l = 0 ; l < QK ; l += 2 ) {
702
707
const uint8_t vi = pp [l /2 ];
@@ -1553,14 +1558,16 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
1553
1558
inline static void ggml_vec_dot_q4_1 (const int n , float * restrict s , const void * restrict x , const void * restrict y ) {
1554
1559
const int nb = n / QK ;
1555
1560
1556
- const float * restrict pm0 = (const float * ) x ;
1557
- const float * restrict pm1 = (const float * ) y ;
1561
+ const size_t bs = 2 * sizeof (float ) + QK /2 ;
1562
+
1563
+ const uint8_t * restrict pd0 = ((const uint8_t * )x + 0 * bs );
1564
+ const uint8_t * restrict pd1 = ((const uint8_t * )y + 0 * bs );
1558
1565
1559
- const float * restrict pd0 = (const float * ) ( pm0 + nb );
1560
- const float * restrict pd1 = (const float * ) ( pm1 + nb );
1566
+ const uint8_t * restrict pm0 = (( const uint8_t * )x + 0 * bs + sizeof ( float ) );
1567
+ const uint8_t * restrict pm1 = (( const uint8_t * )y + 0 * bs + sizeof ( float ) );
1561
1568
1562
- const uint8_t * restrict pb0 = (const uint8_t * ) ( pd0 + nb );
1563
- const uint8_t * restrict pb1 = (const uint8_t * ) ( pd1 + nb );
1569
+ const uint8_t * restrict pb0 = (( const uint8_t * )x + 0 * bs + 2 * sizeof ( float ) );
1570
+ const uint8_t * restrict pb1 = (( const uint8_t * )y + 0 * bs + 2 * sizeof ( float ) );
1564
1571
1565
1572
float sumf = 0.0 ;
1566
1573
@@ -1573,14 +1580,14 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
1573
1580
1574
1581
// Main loop
1575
1582
for (int i = 0 ; i < nb ; ++ i ) {
1576
- const float * m0 = (const float * ) (pm0 + i );
1577
- const float * m1 = (const float * ) (pm1 + i );
1583
+ const float * m0 = (const float * ) (pm0 + i * bs );
1584
+ const float * m1 = (const float * ) (pm1 + i * bs );
1578
1585
1579
- const float * d0 = (const float * ) (pd0 + i );
1580
- const float * d1 = (const float * ) (pd1 + i );
1586
+ const float * d0 = (const float * ) (pd0 + i * bs );
1587
+ const float * d1 = (const float * ) (pd1 + i * bs );
1581
1588
1582
- const uint8_t * restrict p0 = pb0 + i * QK / 2 ;
1583
- const uint8_t * restrict p1 = pb1 + i * QK / 2 ;
1589
+ const uint8_t * restrict p0 = pb0 + i * bs ;
1590
+ const uint8_t * restrict p1 = pb1 + i * bs ;
1584
1591
1585
1592
const __m256 d0v = _mm256_broadcast_ss ( d0 );
1586
1593
const __m256 d1v = _mm256_broadcast_ss ( d1 );
@@ -1646,14 +1653,14 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
1646
1653
#else
1647
1654
// scalar
1648
1655
for (int i = 0 ; i < nb ; i ++ ) {
1649
- const float m0 = pm0 [ i ] ;
1650
- const float m1 = pm1 [ i ] ;
1656
+ const float * m0 = ( const float * ) ( pm0 + i * bs ) ;
1657
+ const float * m1 = ( const float * ) ( pm1 + i * bs ) ;
1651
1658
1652
- const float d0 = pd0 [ i ] ;
1653
- const float d1 = pd1 [ i ] ;
1659
+ const float * d0 = ( const float * ) ( pd0 + i * bs ) ;
1660
+ const float * d1 = ( const float * ) ( pd1 + i * bs ) ;
1654
1661
1655
- const uint8_t * restrict p0 = pb0 + i * QK / 2 ;
1656
- const uint8_t * restrict p1 = pb1 + i * QK / 2 ;
1662
+ const uint8_t * restrict p0 = pb0 + i * bs ;
1663
+ const uint8_t * restrict p1 = pb1 + i * bs ;
1657
1664
1658
1665
for (int j = 0 ; j < QK /2 ; j ++ ) {
1659
1666
const uint8_t v0 = p0 [j ];
0 commit comments