|
| 1 | +include(CheckCSourceRuns) |
| 2 | + |
| 3 | +set(AVX_CODE " |
| 4 | + #include <immintrin.h> |
| 5 | + int main() |
| 6 | + { |
| 7 | + __m256 a; |
| 8 | + a = _mm256_set1_ps(0); |
| 9 | + return 0; |
| 10 | + } |
| 11 | +") |
| 12 | + |
| 13 | +set(AVX512_CODE " |
| 14 | + #include <immintrin.h> |
| 15 | + int main() |
| 16 | + { |
| 17 | + __m512i a = _mm512_set_epi8(0, 0, 0, 0, 0, 0, 0, 0, |
| 18 | + 0, 0, 0, 0, 0, 0, 0, 0, |
| 19 | + 0, 0, 0, 0, 0, 0, 0, 0, |
| 20 | + 0, 0, 0, 0, 0, 0, 0, 0, |
| 21 | + 0, 0, 0, 0, 0, 0, 0, 0, |
| 22 | + 0, 0, 0, 0, 0, 0, 0, 0, |
| 23 | + 0, 0, 0, 0, 0, 0, 0, 0, |
| 24 | + 0, 0, 0, 0, 0, 0, 0, 0); |
| 25 | + __m512i b = a; |
| 26 | + __mmask64 equality_mask = _mm512_cmp_epi8_mask(a, b, _MM_CMPINT_EQ); |
| 27 | + return 0; |
| 28 | + } |
| 29 | +") |
| 30 | + |
| 31 | +set(AVX2_CODE " |
| 32 | + #include <immintrin.h> |
| 33 | + int main() |
| 34 | + { |
| 35 | + __m256i a = {0}; |
| 36 | + a = _mm256_abs_epi16(a); |
| 37 | + __m256i x; |
| 38 | + _mm256_extract_epi64(x, 0); // we rely on this in our AVX2 code |
| 39 | + return 0; |
| 40 | + } |
| 41 | +") |
| 42 | + |
| 43 | +set(FMA_CODE " |
| 44 | + #include <immintrin.h> |
| 45 | + int main() |
| 46 | + { |
| 47 | + __m256 acc = _mm256_setzero_ps(); |
| 48 | + const __m256 d = _mm256_setzero_ps(); |
| 49 | + const __m256 p = _mm256_setzero_ps(); |
| 50 | + acc = _mm256_fmadd_ps( d, p, acc ); |
| 51 | + return 0; |
| 52 | + } |
| 53 | +") |
| 54 | + |
| 55 | +macro(check_sse type flags) |
| 56 | + set(__FLAG_I 1) |
| 57 | + set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS}) |
| 58 | + foreach (__FLAG ${flags}) |
| 59 | + if (NOT ${type}_FOUND) |
| 60 | + set(CMAKE_REQUIRED_FLAGS ${__FLAG}) |
| 61 | + check_c_source_runs("${${type}_CODE}" HAS_${type}_${__FLAG_I}) |
| 62 | + if (HAS_${type}_${__FLAG_I}) |
| 63 | + set(${type}_FOUND TRUE CACHE BOOL "${type} support") |
| 64 | + set(${type}_FLAGS "${__FLAG}" CACHE STRING "${type} flags") |
| 65 | + endif() |
| 66 | + math(EXPR __FLAG_I "${__FLAG_I}+1") |
| 67 | + endif() |
| 68 | + endforeach() |
| 69 | + set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE}) |
| 70 | + |
| 71 | + if (NOT ${type}_FOUND) |
| 72 | + set(${type}_FOUND FALSE CACHE BOOL "${type} support") |
| 73 | + set(${type}_FLAGS "" CACHE STRING "${type} flags") |
| 74 | + endif() |
| 75 | + |
| 76 | + mark_as_advanced(${type}_FOUND ${type}_FLAGS) |
| 77 | +endmacro() |
| 78 | + |
| 79 | +# flags are for MSVC only! |
| 80 | +check_sse("AVX" " ;/arch:AVX") |
| 81 | +if (NOT ${AVX_FOUND}) |
| 82 | + set(LLAMA_AVX OFF) |
| 83 | +else() |
| 84 | + set(LLAMA_AVX ON) |
| 85 | +endif() |
| 86 | + |
| 87 | +check_sse("AVX2" " ;/arch:AVX2") |
| 88 | +check_sse("FMA" " ;/arch:AVX2") |
| 89 | +if ((NOT ${AVX2_FOUND}) OR (NOT ${FMA_FOUND})) |
| 90 | + set(LLAMA_AVX2 OFF) |
| 91 | +else() |
| 92 | + set(LLAMA_AVX2 ON) |
| 93 | +endif() |
| 94 | + |
| 95 | +check_sse("AVX512" " ;/arch:AVX512") |
| 96 | +if (NOT ${AVX512_FOUND}) |
| 97 | + set(LLAMA_AVX512 OFF) |
| 98 | +else() |
| 99 | + set(LLAMA_AVX512 ON) |
| 100 | +endif() |
0 commit comments