@@ -64,6 +64,114 @@ option(LLAMA_OPENBLAS "llama: use OpenBLAS"
64
64
option (LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE} )
65
65
option (LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE} )
66
66
67
+ INCLUDE (CheckCSourceRuns)
68
+
69
+ SET (AVX_CODE "
70
+ #include <immintrin.h>
71
+ int main()
72
+ {
73
+ __m256 a;
74
+ a = _mm256_set1_ps(0);
75
+ return 0;
76
+ }
77
+ " )
78
+
79
+ SET (AVX512_CODE "
80
+ #include <immintrin.h>
81
+ int main()
82
+ {
83
+ __m512i a = _mm512_set_epi8(0, 0, 0, 0, 0, 0, 0, 0,
84
+ 0, 0, 0, 0, 0, 0, 0, 0,
85
+ 0, 0, 0, 0, 0, 0, 0, 0,
86
+ 0, 0, 0, 0, 0, 0, 0, 0,
87
+ 0, 0, 0, 0, 0, 0, 0, 0,
88
+ 0, 0, 0, 0, 0, 0, 0, 0,
89
+ 0, 0, 0, 0, 0, 0, 0, 0,
90
+ 0, 0, 0, 0, 0, 0, 0, 0);
91
+ __m512i b = a;
92
+ __mmask64 equality_mask = _mm512_cmp_epi8_mask(a, b, _MM_CMPINT_EQ);
93
+ return 0;
94
+ }
95
+ " )
96
+
97
+ SET (AVX2_CODE "
98
+ #include <immintrin.h>
99
+ int main()
100
+ {
101
+ __m256i a = {0};
102
+ a = _mm256_abs_epi16(a);
103
+ __m256i x;
104
+ _mm256_extract_epi64(x, 0); // we rely on this in our AVX2 code
105
+ return 0;
106
+ }
107
+ " )
108
+
109
+ SET (FMA_CODE "
110
+ #include <immintrin.h>
111
+ int main()
112
+ {
113
+ __m256 acc = _mm256_setzero_ps();
114
+ const __m256 d = _mm256_setzero_ps();
115
+ const __m256 p = _mm256_setzero_ps();
116
+ acc = _mm256_fmadd_ps( d, p, acc );
117
+ return 0;
118
+ }
119
+ " )
120
+
121
+ MACRO (CHECK_SSE type flags )
122
+ SET (__FLAG_I 1)
123
+ SET (CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS} )
124
+ FOREACH (__FLAG ${flags} )
125
+ IF (NOT ${type} _FOUND)
126
+ SET (CMAKE_REQUIRED_FLAGS ${__FLAG} )
127
+ CHECK_C_SOURCE_RUNS("${${type} _CODE}" HAS_${type} _${__FLAG_I} )
128
+ IF (HAS_${type} _${__FLAG_I} )
129
+ SET (${type} _FOUND TRUE CACHE BOOL "${type} support" )
130
+ SET (${type} _FLAGS "${__FLAG} " CACHE STRING "${type} flags" )
131
+ ENDIF ()
132
+ MATH (EXPR __FLAG_I "${__FLAG_I} +1" )
133
+ ENDIF ()
134
+ ENDFOREACH ()
135
+ SET (CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE} )
136
+
137
+ IF (NOT ${type} _FOUND)
138
+ SET (${type} _FOUND FALSE CACHE BOOL "${type} support" )
139
+ SET (${type} _FLAGS "" CACHE STRING "${type} flags" )
140
+ ENDIF ()
141
+
142
+ MARK_AS_ADVANCED (${type} _FOUND ${type} _FLAGS)
143
+
144
+ ENDMACRO ()
145
+
146
+ CHECK_SSE("AVX" " ;-mavx;/arch:AVX" )
147
+ CHECK_SSE("AVX2" " ;-mavx2 -mfma;/arch:AVX2" )
148
+ CHECK_SSE("AVX512" " ;-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma;/arch:AVX512" )
149
+ CHECK_SSE("FMA" " ;-mfma;" )
150
+
151
+ IF (${AVX_FOUND} )
152
+ set (LLAMA_AVX ON )
153
+ ELSE ()
154
+ set (LLAMA_AVX OFF )
155
+ ENDIF ()
156
+
157
+ IF (${FMA_FOUND} )
158
+ set (LLAMA_FMA ON )
159
+ ELSE ()
160
+ set (LLAMA_FMA OFF )
161
+ ENDIF ()
162
+
163
+ IF (${AVX2_FOUND} )
164
+ set (LLAMA_AVX2 ON )
165
+ ELSE ()
166
+ set (LLAMA_AVX2 OFF )
167
+ ENDIF ()
168
+
169
+ IF (${AVX512_FOUND} )
170
+ set (LLAMA_AVX512 ON )
171
+ ELSE ()
172
+ set (LLAMA_AVX512 OFF )
173
+ ENDIF ()
174
+
67
175
#
68
176
# Compile flags
69
177
#
0 commit comments