Skip to content

Commit 4e57aa6

Browse files
committed
Add provisions for windows support for BF16 branch including CMake provision for enabling AVX512_BF16
1 parent 82aebcf commit 4e57aa6

File tree

5 files changed

+38
-8
lines changed

5 files changed

+38
-8
lines changed

CMakeLists.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ option(LLAMA_AVX2 "llama: enable AVX2"
7777
option(LLAMA_AVX512 "llama: enable AVX512" OFF)
7878
option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF)
7979
option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF)
80+
option(LLAMA_AVX512_BF16 "llama: enable AVX512-BF16" OFF)
8081
option(LLAMA_FMA "llama: enable FMA" ${INS_ENB})
8182
# in MSVC F16C is implied with AVX2/AVX512
8283
if (NOT MSVC)
@@ -1037,6 +1038,10 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
10371038
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)
10381039
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
10391040
endif()
1041+
if (LLAMA_AVX512_BF16)
1042+
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512BF16__>)
1043+
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512BF16__>)
1044+
endif()
10401045
elseif (LLAMA_AVX2)
10411046
list(APPEND ARCH_FLAGS /arch:AVX2)
10421047
elseif (LLAMA_AVX)
@@ -1068,6 +1073,9 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
10681073
if (LLAMA_AVX512_VNNI)
10691074
list(APPEND ARCH_FLAGS -mavx512vnni)
10701075
endif()
1076+
if (LLAMA_AVX512_BF16)
1077+
add_compile_options(APPEND ARCH_FLAGS -mavx512bf16)
1078+
endif()
10711079
endif()
10721080
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
10731081
message(STATUS "PowerPC detected")

ggml-impl.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,18 @@
1717
#define MIN(a, b) ((a) < (b) ? (a) : (b))
1818
#define MAX(a, b) ((a) > (b) ? (a) : (b))
1919

20+
#if defined(_WIN32)
21+
22+
#define m512bh(p) p
23+
#define m512i(p) p
24+
25+
#else
26+
27+
#define m512bh(p) (__m512bh)(p)
28+
#define m512i(p) (__m512i)(p)
29+
30+
#endif
31+
2032
/**
2133
* Converts brain16 to float32.
2234
*

ggml.c

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -410,10 +410,10 @@ void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) {
410410
int i = 0;
411411
#if defined(__AVX512BF16__)
412412
for (; i + 32 <= n; i += 32) {
413-
_mm512_storeu_ps(
414-
(__m512 *)(y + i),
415-
(__m512)_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
416-
_mm512_loadu_ps(x + i)));
413+
_mm512_storeu_si512(
414+
(__m512i *)(y + i),
415+
m512i(_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
416+
_mm512_loadu_ps(x + i))));
417417
}
418418
#endif
419419
for (; i < n; i++) {
@@ -1615,10 +1615,10 @@ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t
16151615
__m512 c1 = _mm512_setzero_ps();
16161616
__m512 c2 = _mm512_setzero_ps();
16171617
for (; i + 64 <= n; i += 64) {
1618-
c1 = _mm512_dpbf16_ps(c1, (__m512bh)_mm512_loadu_ps((const float *)(x + i)),
1619-
(__m512bh)_mm512_loadu_ps((const float *)(y + i)));
1620-
c2 = _mm512_dpbf16_ps(c2, (__m512bh)_mm512_loadu_ps((const float *)(x + i + 32)),
1621-
(__m512bh)_mm512_loadu_ps((const float *)(y + i + 32)));
1618+
c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))),
1619+
m512bh(_mm512_loadu_si512((y + i))));
1620+
c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))),
1621+
m512bh(_mm512_loadu_si512((y + i + 32))));
16221622
}
16231623
sumf += (ggml_float)_mm512_reduce_add_ps(c1);
16241624
sumf += (ggml_float)_mm512_reduce_add_ps(c2);
@@ -23028,6 +23028,14 @@ int ggml_cpu_has_avx512_vnni(void) {
2302823028
#endif
2302923029
}
2303023030

23031+
int ggml_cpu_has_avx512_bf16(void) {
23032+
#if defined(__AVX512BF16__)
23033+
return 1;
23034+
#else
23035+
return 0;
23036+
#endif
23037+
}
23038+
2303123039
int ggml_cpu_has_fma(void) {
2303223040
#if defined(__FMA__)
2303323041
return 1;

ggml.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2379,6 +2379,7 @@ extern "C" {
23792379
GGML_API int ggml_cpu_has_avx512 (void);
23802380
GGML_API int ggml_cpu_has_avx512_vbmi(void);
23812381
GGML_API int ggml_cpu_has_avx512_vnni(void);
2382+
GGML_API int ggml_cpu_has_avx512_bf16(void);
23822383
GGML_API int ggml_cpu_has_fma (void);
23832384
GGML_API int ggml_cpu_has_neon (void);
23842385
GGML_API int ggml_cpu_has_arm_fma (void);

llama.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17886,6 +17886,7 @@ const char * llama_print_system_info(void) {
1788617886
s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
1788717887
s += "AVX512_VBMI = " + std::to_string(ggml_cpu_has_avx512_vbmi()) + " | ";
1788817888
s += "AVX512_VNNI = " + std::to_string(ggml_cpu_has_avx512_vnni()) + " | ";
17889+
s += "AVX512_BF16 = " + std::to_string(ggml_cpu_has_avx512_bf16()) + " | ";
1788917890
s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
1789017891
s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
1789117892
s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";

0 commit comments

Comments
 (0)