@@ -410,10 +410,10 @@ void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) {
410
410
int i = 0;
411
411
#if defined(__AVX512BF16__)
412
412
for (; i + 32 <= n; i += 32) {
413
- _mm512_storeu_ps (
414
- (__m512 *)(y + i),
415
- (__m512) _mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
416
- _mm512_loadu_ps(x + i)));
413
+ _mm512_storeu_si512 (
414
+ (__m512i *)(y + i),
415
+ m512i( _mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
416
+ _mm512_loadu_ps(x + i) )));
417
417
}
418
418
#endif
419
419
for (; i < n; i++) {
@@ -1615,10 +1615,10 @@ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t
1615
1615
__m512 c1 = _mm512_setzero_ps();
1616
1616
__m512 c2 = _mm512_setzero_ps();
1617
1617
for (; i + 64 <= n; i += 64) {
1618
- c1 = _mm512_dpbf16_ps(c1, (__m512bh)_mm512_loadu_ps((const float *)( x + i)),
1619
- (__m512bh)_mm512_loadu_ps((const float *)( y + i)));
1620
- c2 = _mm512_dpbf16_ps(c2, (__m512bh)_mm512_loadu_ps((const float *)( x + i + 32)),
1621
- (__m512bh)_mm512_loadu_ps((const float *)( y + i + 32)));
1618
+ c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512(( x + i) )),
1619
+ m512bh(_mm512_loadu_si512(( y + i) )));
1620
+ c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512(( x + i + 32) )),
1621
+ m512bh(_mm512_loadu_si512(( y + i + 32) )));
1622
1622
}
1623
1623
sumf += (ggml_float)_mm512_reduce_add_ps(c1);
1624
1624
sumf += (ggml_float)_mm512_reduce_add_ps(c2);
@@ -23028,6 +23028,14 @@ int ggml_cpu_has_avx512_vnni(void) {
23028
23028
#endif
23029
23029
}
23030
23030
23031
// Reports whether this binary was compiled with AVX-512 BF16 support
// (i.e. the compiler defined __AVX512BF16__). Returns 1 if so, 0 otherwise.
// Note: this is a compile-time probe, not a runtime CPUID check — it matches
// the convention of the other ggml_cpu_has_* feature queries.
int ggml_cpu_has_avx512_bf16(void) {
#if defined(__AVX512BF16__)
    const int has_avx512_bf16 = 1;
#else
    const int has_avx512_bf16 = 0;
#endif
    return has_avx512_bf16;
}
23031
23039
int ggml_cpu_has_fma(void) {
23032
23040
#if defined(__FMA__)
23033
23041
return 1;
0 commit comments