@@ -2288,17 +2288,31 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
2288
2288
const uint8_t * restrict p0 = x [i ].qs ;
2289
2289
const uint8_t * restrict p1 = y [i ].qs ;
2290
2290
2291
- for (int j = 0 ; j < QK /2 ; j ++ ) {
2292
- const uint8_t v0 = p0 [j ];
2293
- const uint8_t v1 = p1 [j ];
2294
-
2295
- const float f0 = d0 * (v0 & 0xf ) + m0 ;
2296
- const float f1 = d0 * (v0 >> 4 ) + m0 ;
2297
-
2298
- const float f2 = d1 * (v1 & 0xf ) + m1 ;
2299
- const float f3 = d1 * (v1 >> 4 ) + m1 ;
2300
-
2301
- sumf += f0 * f2 + f1 * f3 ;
2291
+ for (int j = 0 ; j < QK /4 ; j ++ ) {
2292
+ const uint32_t v0 = ((uint32_t * )p0 )[j ];
2293
+ const uint32_t v1 = ((uint32_t * )p1 )[j ];
2294
+
2295
+ const uint8_t v0_0 = (v0 >> 0 ) & 0xf ;
2296
+ const uint8_t v0_1 = (v0 >> 4 ) & 0xf ;
2297
+ const uint8_t v0_2 = (v0 >> 8 ) & 0xf ;
2298
+ const uint8_t v0_3 = (v0 >> 12 ) & 0xf ;
2299
+
2300
+ const uint8_t v1_0 = (v1 >> 0 ) & 0xf ;
2301
+ const uint8_t v1_1 = (v1 >> 4 ) & 0xf ;
2302
+ const uint8_t v1_2 = (v1 >> 8 ) & 0xf ;
2303
+ const uint8_t v1_3 = (v1 >> 12 ) & 0xf ;
2304
+
2305
+ const float f0 = d0 * v0_0 + m0 ;
2306
+ const float f1 = d0 * v0_1 + m0 ;
2307
+ const float f2 = d0 * v0_2 + m0 ;
2308
+ const float f3 = d0 * v0_3 + m0 ;
2309
+
2310
+ const float f4 = d1 * v1_0 + m1 ;
2311
+ const float f5 = d1 * v1_1 + m1 ;
2312
+ const float f6 = d1 * v1_2 + m1 ;
2313
+ const float f7 = d1 * v1_3 + m1 ;
2314
+
2315
+ sumf += f0 * f4 + f1 * f5 + f2 * f6 + f3 * f7 ;
2302
2316
}
2303
2317
}
2304
2318
#endif
0 commit comments