@@ -583,7 +583,7 @@ static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 bloc
 
 typedef struct {
     float   d;          // delta
-    uint8_t qs[QK];     // nibbles / quants
+    int8_t  qs[QK];     // quants
 } block_q8_0;
 static_assert(sizeof(block_q8_0) == sizeof(float) + QK, "wrong q8_0 block size/padding");
 
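With this change a block_q8_0 is one float scale followed by QK signed bytes, so dequantization reduces to a multiply per element. A minimal sketch (the helper name is ours, not from the patch; it assumes QK == 32 as elsewhere in ggml.c):

    // Sketch only: dequantize one block_q8_0 back to floats.
    static void dequantize_block_q8_0_sketch(const block_q8_0 * b, float * out) {
        for (int l = 0; l < QK; ++l) {
            out[l] = b->d * b->qs[l];   // signed byte scaled by the block delta
        }
    }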
@@ -1060,9 +1060,7 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
 
         for (int l = 0; l < QK; ++l) {
             const float v = x[i*QK + l]*id;
-            const uint8_t vi = (int8_t)roundf(v) + 128;
-
-            y[i].qs[l] = vi;
+            y[i].qs[l] = roundf(v);
         }
     }
 }
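The reference quantizer now writes plain signed bytes: the block scale d comes from the absolute maximum so that the largest magnitude maps to +/-127, and each element is rounded directly, with the old +128 bias gone. A hedged per-block sketch (the standalone form and helper name are ours; the real quantize_row_q8_0_reference loops over nb such blocks):

    #include <math.h>

    // Sketch, not the verbatim function: quantize one block of QK floats to q8_0.
    static void quantize_block_q8_0_sketch(const float * x, block_q8_0 * y) {
        float amax = 0.0f;                            // absolute max of the block
        for (int l = 0; l < QK; ++l) {
            amax = fmaxf(amax, fabsf(x[l]));
        }
        const float d  = amax / 127.0f;               // delta: largest value maps to +/-127
        const float id = d != 0.0f ? 1.0f/d : 0.0f;   // inverse delta, guarded against zero
        y->d = d;
        for (int l = 0; l < QK; ++l) {
            y->qs[l] = (int8_t) roundf(x[l]*id);      // straight signed byte, no +128 bias
        }
    }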
@@ -1095,15 +1093,99 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
 
         for (int l = 0; l < 8; l++) {
             const float32x4_t v  = vmulq_n_f32(srcv[l], id);
-            const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(128.5f));
-            const int32x4_t   vi = vcvtq_s32_f32(vf);
+            //TODO: rounding
+            const int32x4_t   vi = vcvtq_s32_f32(v);
 
             y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
             y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
             y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
             y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
         }
     }
+#elif defined(__AVX2__) || defined(__AVX__)
+    for (int i = 0; i < nb; i++) {
+        // Load elements into 4 AVX vectors
+        __m256 v0 = _mm256_loadu_ps( x );
+        __m256 v1 = _mm256_loadu_ps( x + 8 );
+        __m256 v2 = _mm256_loadu_ps( x + 16 );
+        __m256 v3 = _mm256_loadu_ps( x + 24 );
+        x += 32;
+
+        // Compute max(abs(e)) for the block
+        const __m256 signBit = _mm256_set1_ps( -0.0f );
+        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+        const float maxScalar = _mm_cvtss_f32( max4 );
+
+        // Quantize these floats
+        const float d = maxScalar / 127.f;
+        y[i].d = d;
+        const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+        const __m256 mul = _mm256_set1_ps( id );
+
+        // Apply the multiplier
+        v0 = _mm256_mul_ps( v0, mul );
+        v1 = _mm256_mul_ps( v1, mul );
+        v2 = _mm256_mul_ps( v2, mul );
+        v3 = _mm256_mul_ps( v3, mul );
+
+        // Round to nearest integer
+        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+        // Convert floats to integers
+        __m256i i0 = _mm256_cvtps_epi32( v0 );
+        __m256i i1 = _mm256_cvtps_epi32( v1 );
+        __m256i i2 = _mm256_cvtps_epi32( v2 );
+        __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+#if defined(__AVX2__)
+        // Convert int32 to int16
+        i0 = _mm256_packs_epi32( i0, i1 );  // 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
+        i2 = _mm256_packs_epi32( i2, i3 );  // 16, 17, 18, 19,  24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31
+        // Convert int16 to int8
+        i0 = _mm256_packs_epi16( i0, i2 );  // 0, 1, 2, 3,  8, 9, 10, 11,  16, 17, 18, 19,  24, 25, 26, 27,  4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+        // We got our precious signed bytes, but the order is now wrong
+        // These AVX2 pack instructions process 16-byte pieces independently
+        // The following instruction fixes the order
+        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+        i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+        _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+#else
+        // Since AVX lacks some of the necessary intrinsics,
+        // we split the registers in half and call the SSE analogs of the AVX2 instructions
+        __m128i ni0 = _mm256_castsi256_si128( i0 );
+        __m128i ni1 = _mm256_extractf128_si256( i0, 1 );
+        __m128i ni2 = _mm256_castsi256_si128( i1 );
+        __m128i ni3 = _mm256_extractf128_si256( i1, 1 );
+        __m128i ni4 = _mm256_castsi256_si128( i2 );
+        __m128i ni5 = _mm256_extractf128_si256( i2, 1 );
+        __m128i ni6 = _mm256_castsi256_si128( i3 );
+        __m128i ni7 = _mm256_extractf128_si256( i3, 1 );
+
+        // Convert int32 to int16
+        ni0 = _mm_packs_epi32( ni0, ni1 );
+        ni2 = _mm_packs_epi32( ni2, ni3 );
+        ni4 = _mm_packs_epi32( ni4, ni5 );
+        ni6 = _mm_packs_epi32( ni6, ni7 );
+        // Convert int16 to int8
+        ni0 = _mm_packs_epi16( ni0, ni2 );
+        ni4 = _mm_packs_epi16( ni4, ni6 );
+
+        _mm_storeu_si128((__m128i *)(y[i].qs +  0), ni0);
+        _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
+#endif
+    }
 #else
     // scalar
     quantize_row_q8_0_reference(x, y, k);
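On the NEON side the patch leaves a TODO: vcvtq_s32_f32 truncates toward zero, so once the old +128.5f bias is dropped the conversion no longer rounds to nearest. One possible follow-up, shown only as a hedged sketch and not something this patch does, is the AArch64 round-to-nearest convert:

    #include <arm_neon.h>

    // Sketch, not part of the patch: vcvtnq_s32_f32 converts float32 -> int32
    // with round-to-nearest-even, one way to resolve the TODO above on AArch64.
    static inline int32x4_t round_f32_to_s32_sketch(float32x4_t v) {
        return vcvtnq_s32_f32(v);
    }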
@@ -2508,7 +2590,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
 
         const uint8x16_t m4b   = vdupq_n_u8(0xf);
         const int8x16_t  s8b   = vdupq_n_s8(0x8);
-        const uint8x16_t u128b = vdupq_n_u8(128);
 
         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
         const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@@ -2526,21 +2607,16 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
 
         // load y
-        const uint8x16_t v1_0l = vld1q_u8(y0->qs);
-        const uint8x16_t v1_0h = vld1q_u8(y0->qs + 16);
-        const uint8x16_t v1_1l = vld1q_u8(y1->qs);
-        const uint8x16_t v1_1h = vld1q_u8(y1->qs + 16);
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
         // interleave
-        const uint8x16_t v1_0lz = vuzp1q_u8(v1_0l, v1_0h);
-        const uint8x16_t v1_0hz = vuzp2q_u8(v1_0l, v1_0h);
-        const uint8x16_t v1_1lz = vuzp1q_u8(v1_1l, v1_1h);
-        const uint8x16_t v1_1hz = vuzp2q_u8(v1_1l, v1_1h);
-
-        const int8x16_t v1_0ls = vreinterpretq_s8_u8(vsubq_u8(v1_0lz, u128b));
-        const int8x16_t v1_0hs = vreinterpretq_s8_u8(vsubq_u8(v1_0hz, u128b));
-        const int8x16_t v1_1ls = vreinterpretq_s8_u8(vsubq_u8(v1_1lz, u128b));
-        const int8x16_t v1_1hs = vreinterpretq_s8_u8(vsubq_u8(v1_1hz, u128b));
+        const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
+        const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
+        const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
+        const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
 
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
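The q8_0 operand is now loaded as signed bytes and de-interleaved with vuzp1q_s8/vuzp2q_s8 because q4_0 keeps quants 2j and 2j+1 in the low and high nibble of byte j, so the even- and odd-indexed q8_0 values have to be separated to line up with them. A hedged scalar model of that shuffle (names are ours):

    // Sketch: what the vuzp1q_s8/vuzp2q_s8 pair above does, written out in scalar form.
    static void uzp_q8_sketch(const int8_t y[32], int8_t even[16], int8_t odd[16]) {
        for (int j = 0; j < 16; ++j) {
            even[j] = y[2*j + 0];   // lines up with the low  nibbles (vuzp1)
            odd[j]  = y[2*j + 1];   // lines up with the high nibbles (vuzp2)
        }
    }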
@@ -2578,14 +2654,102 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     }
 
     sumf = sum0 + sum1;
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        /* Compute combined scale for the block */
+        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+
+        __m256i bx = bytesFromNibbles(x[i].qs);
+
+        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+        const __m256i off = _mm256_set1_epi8( 8 );
+        bx = _mm256_sub_epi8( bx, off );
+
+        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        // Get absolute values of x vectors
+        const __m256i ax = _mm256_sign_epi8(bx, bx);
+
+        // Sign the values of the y vectors
+        const __m256i sy = _mm256_sign_epi8(by, bx);
+
+        // Perform multiplication and create 16-bit values
+        const __m256i dot = _mm256_maddubs_epi16(ax, sy);
+
+        const __m256i ones = _mm256_set1_epi16(1);
+        __m256i xy_q = _mm256_madd_epi16(ones, dot);
+
+        /* Convert the vector of 8 int32_t to 8 floats */
+        __m256 q = _mm256_cvtepi32_ps( xy_q );
+
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_fmadd_ps( d, q, acc );
+    }
+
+    // Return horizontal sum of the acc vector
+    __m128 res = _mm256_extractf128_ps( acc, 1 );
+    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
+    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
+    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
+
+    sumf = _mm_cvtss_f32( res );
+#elif defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        // Compute combined scale for the block
+        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+
+        __m128i i32[2];
+        for (int j = 0; j < 2; ++j) {
+            // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
+            __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
+            __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
+
+            // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+            const __m128i off = _mm_set1_epi8( 8 );
+            bx = _mm_sub_epi8( bx, off );
+
+            // Get absolute values of x vectors
+            const __m128i ax = _mm_sign_epi8(bx, bx);
+
+            // Sign the values of the y vectors
+            const __m128i sy = _mm_sign_epi8(by, bx);
+
+            // Perform multiplication and create 16-bit values
+            const __m128i dot = _mm_maddubs_epi16(ax, sy);
+
+            const __m128i ones = _mm_set1_epi16(1);
+            i32[j] = _mm_madd_epi16(ones, dot);
+        }
+
+        // Convert int32_t to float
+        __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
+        // Apply the scale, and accumulate
+        acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
+    }
+
+    // Return horizontal sum of the acc vector
+    __m128 res = _mm256_extractf128_ps( acc, 1 );
+    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
+    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
+    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
+
+    sumf = _mm_cvtss_f32( res );
 #else
     // scalar
     for (int i = 0; i < nb; i++) {
         const float d0 = x[i].d;
         const float d1 = y[i].d;
 
         const uint8_t * restrict p0 = x[i].qs;
-        const uint8_t * restrict p1 = y[i].qs;
+        const  int8_t * restrict p1 = y[i].qs;
 
         int sumi = 0;
         for (int j = 0; j < QK/2; j++) {
@@ -2594,10 +2758,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
             const int i0 = (int8_t) (v0 & 0xf) - 8;
             const int i1 = (int8_t) (v0 >> 4) - 8;
 
-            const int i2 = (int) p1[2*j + 0] - 128;
-            const int i3 = (int) p1[2*j + 1] - 128;
-
-            /*printf("dot product: i0=%4d i1=%4d i2=%4d i3=%4d\n", i0, i1, i2, i3);*/
+            const int i2 = p1[2*j + 0];
+            const int i3 = p1[2*j + 1];
 
             sumi += i0*i2 + i1*i3;
         }
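In the AVX2/AVX paths above, _mm256_maddubs_epi16 and _mm_maddubs_epi16 require an unsigned first operand, so the kernel multiplies |x| by a copy of y that carries x's sign; each product equals x*y. A hedged scalar model of that identity (function name is ours, not from the patch):

    // Sketch: scalar equivalent of the sign trick used with maddubs above.
    // |x| * (sign(x)*y) == x*y for every int8 pair; x == 0 zeroes the term.
    static int dot_sign_trick_sketch(const int8_t * x, const int8_t * y, int n) {
        int sum = 0;
        for (int k = 0; k < n; ++k) {
            const int ax = x[k] < 0 ? -x[k] : x[k];               // like _mm256_sign_epi8(bx, bx)
            const int sy = x[k] < 0 ? -y[k] : (x[k] ? y[k] : 0);  // like _mm256_sign_epi8(by, bx)
            sum += ax * sy;                                       // equals x[k]*y[k]
        }
        return sum;
    }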
@@ -9923,7 +10085,9 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
                     } else
 #endif
-                    cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
+                    {
+                        cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
+                    }
                 } else {
                     GGML_ASSERT(false);
                 }