Skip to content

Commit b801080

Browse files
author
Tony Kuo
committed
avx runtime check
1 parent 07c68cc commit b801080

File tree

3 files changed

+159
-15
lines changed

3 files changed

+159
-15
lines changed

hnswlib/hnswlib.h

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,86 @@
3737
#include <iostream>
3838
#include <string.h>
3939

40+
// Adapted from https://github.com/Mysticial/FeatureDetector
41+
#define _XCR_XFEATURE_ENABLED_MASK 0
42+
#ifdef _WIN32
43+
// Executes the CPUID instruction (MSVC build) for leaf `eax` / sub-leaf `ecx`
// and stores the resulting EAX, EBX, ECX, EDX values in out[0..3].
//
// `static inline` gives the helper internal linkage: it is defined in a header,
// and an external-linkage definition would collide at link time as soon as the
// header is included from more than one translation unit.
static inline void cpuid(int32_t out[4], int32_t eax, int32_t ecx) {
    __cpuidex(out, eax, ecx);
}
46+
// Reads the extended control register selected by `x` (MSVC build) through
// the _xgetbv intrinsic and returns its 64-bit value.
//
// `static inline` keeps the definition private to each translation unit so
// that including this header more than once per program does not produce
// duplicate-symbol link errors.
static inline __int64 xgetbv(unsigned int x) {
    return _xgetbv(x);
}
49+
#else
50+
#include <cpuid.h>
51+
#include <stdint.h>
52+
// Executes the CPUID instruction (GCC/Clang build) for leaf `eax` / sub-leaf
// `ecx` and stores the resulting EAX, EBX, ECX, EDX values in cpuInfo[0..3].
//
// `static inline` gives the helper internal linkage: it is defined in a header,
// and an external-linkage definition would collide at link time as soon as the
// header is included from more than one translation unit.
static inline void cpuid(int32_t cpuInfo[4], int32_t eax, int32_t ecx) {
    __cpuid_count(eax, ecx, cpuInfo[0], cpuInfo[1], cpuInfo[2], cpuInfo[3]);
}
55+
56+
// Reads the extended control register selected by `index` with the XGETBV
// instruction (GCC/Clang build) and returns EDX:EAX combined into 64 bits.
//
// Callers must first confirm via CPUID that the OS has enabled XSAVE
// (OSXSAVE); otherwise the instruction faults.
//
// `static inline` keeps the definition private to each translation unit so
// that including this header more than once per program does not produce
// duplicate-symbol link errors.
static inline uint64_t xgetbv(unsigned int index) {
    uint32_t eax, edx;
    __asm__ __volatile__("xgetbv" : "=a"(eax), "=d"(edx) : "c"(index));
    return ((uint64_t)edx << 32) | eax;
}
61+
#endif
62+
63+
bool AVXCapable() {
64+
int cpuInfo[4];
65+
66+
// CPU support
67+
cpuid(cpuInfo, 0, 0);
68+
int nIds = cpuInfo[0];
69+
70+
bool HW_AVX = false;
71+
if (nIds >= 0x00000001) {
72+
cpuid(cpuInfo, 0x00000001, 0);
73+
HW_AVX = (cpuInfo[2] & ((int)1 << 28)) != 0;
74+
}
75+
76+
// OS support
77+
cpuid(cpuInfo, 1, 0);
78+
79+
bool osUsesXSAVE_XRSTORE = (cpuInfo[2] & (1 << 27)) != 0;
80+
bool cpuAVXSuport = (cpuInfo[2] & (1 << 28)) != 0;
81+
82+
bool avxSupported = false;
83+
if (osUsesXSAVE_XRSTORE && cpuAVXSuport) {
84+
uint64_t xcrFeatureMask = xgetbv(_XCR_XFEATURE_ENABLED_MASK);
85+
avxSupported = (xcrFeatureMask & 0x6) == 0x6;
86+
}
87+
return avxSupported;
88+
}
89+
90+
bool AVX512Capable() {
91+
if (!AVXCapable()) return false;
92+
93+
int cpuInfo[4];
94+
95+
// CPU support
96+
cpuid(cpuInfo, 0, 0);
97+
int nIds = cpuInfo[0];
98+
99+
bool HW_AVX512F = false; // AVX512 Foundation
100+
if (nIds >= 0x00000007) {
101+
cpuid(cpuInfo, 0x00000007, 0);
102+
HW_AVX512F = (cpuInfo[1] & ((int)1 << 16)) != 0;
103+
}
104+
105+
// OS support
106+
cpuid(cpuInfo, 1, 0);
107+
108+
bool osUsesXSAVE_XRSTORE = (cpuInfo[2] & (1 << 27)) != 0;
109+
bool cpuAVXSuport = (cpuInfo[2] & (1 << 28)) != 0;
110+
111+
bool avxSupported = false;
112+
bool avx512Supported = false;
113+
if (osUsesXSAVE_XRSTORE && cpuAVXSuport) {
114+
uint64_t xcrFeatureMask = xgetbv(_XCR_XFEATURE_ENABLED_MASK);
115+
avx512Supported = (xcrFeatureMask & 0xe6) == 0xe6;
116+
}
117+
return avx512Supported;
118+
}
119+
40120
namespace hnswlib {
41121
typedef size_t labeltype;
42122

@@ -108,7 +188,6 @@ namespace hnswlib {
108188

109189
return result;
110190
}
111-
112191
}
113192

114193
#include "space_l2.h"

hnswlib/space_ip.h

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ namespace hnswlib {
1818

1919
// Favor using AVX if available.
2020
static float
21-
InnerProductSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
21+
InnerProductSIMD4ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
2222
float PORTABLE_ALIGN32 TmpRes[8];
2323
float *pVect1 = (float *) pVect1v;
2424
float *pVect2 = (float *) pVect2v;
@@ -64,10 +64,12 @@ namespace hnswlib {
6464
return 1.0f - sum;
6565
}
6666

67-
#elif defined(USE_SSE)
67+
#endif
68+
69+
#if defined(USE_SSE)
6870

6971
static float
70-
InnerProductSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
72+
InnerProductSIMD4ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
7173
float PORTABLE_ALIGN32 TmpRes[8];
7274
float *pVect1 = (float *) pVect1v;
7375
float *pVect2 = (float *) pVect2v;
@@ -128,7 +130,7 @@ namespace hnswlib {
128130
#if defined(USE_AVX512)
129131

130132
static float
131-
InnerProductSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
133+
InnerProductSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
132134
float PORTABLE_ALIGN64 TmpRes[16];
133135
float *pVect1 = (float *) pVect1v;
134136
float *pVect2 = (float *) pVect2v;
@@ -157,10 +159,12 @@ namespace hnswlib {
157159
return 1.0f - sum;
158160
}
159161

160-
#elif defined(USE_AVX)
162+
#endif
163+
164+
#if defined(USE_AVX)
161165

162166
static float
163-
InnerProductSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
167+
InnerProductSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
164168
float PORTABLE_ALIGN32 TmpRes[8];
165169
float *pVect1 = (float *) pVect1v;
166170
float *pVect2 = (float *) pVect2v;
@@ -195,10 +199,12 @@ namespace hnswlib {
195199
return 1.0f - sum;
196200
}
197201

198-
#elif defined(USE_SSE)
202+
#endif
203+
204+
#if defined(USE_SSE)
199205

200206
static float
201-
InnerProductSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
207+
InnerProductSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
202208
float PORTABLE_ALIGN32 TmpRes[8];
203209
float *pVect1 = (float *) pVect1v;
204210
float *pVect2 = (float *) pVect2v;
@@ -245,6 +251,41 @@ namespace hnswlib {
245251
#endif
246252

247253
#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
254+
static float
255+
InnerProductSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
256+
DISTFUNC<float> simdfunc_;
257+
#if defined(USE_AVX)
258+
if (AVXCapable())
259+
simdfunc_ = InnerProductSIMD4ExtAVX;
260+
else
261+
simdfunc_ = InnerProductSIMD4ExtSSE;
262+
#else
263+
simdfunc_ = InnerProductSIMD4ExtSSE;
264+
#endif
265+
return simdfunc_(pVect1v, pVect2v, qty_ptr);
266+
}
267+
268+
static float
269+
InnerProductSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
270+
DISTFUNC<float> simdfunc_;
271+
#if defined(USE_AVX512)
272+
if (AVX512Capable())
273+
simdfunc_ = InnerProductSIMD16ExtAVX512;
274+
else if (AVXCapable())
275+
simdfunc_ = InnerProductSIMD16ExtAVX;
276+
else
277+
simdfunc_ = InnerProductSIMD16ExtSSE;
278+
#elif defined(USE_AVX)
279+
if (AVXCapable())
280+
simdfunc_ = InnerProductSIMD16ExtAVX;
281+
else
282+
simdfunc_ = InnerProductSIMD16ExtSSE;
283+
#else
284+
simdfunc_ = InnerProductSIMD16ExtSSE;
285+
#endif
286+
return simdfunc_(pVect1v, pVect2v, qty_ptr);
287+
}
288+
248289
static float
249290
InnerProductSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
250291
size_t qty = *((size_t *) qty_ptr);

hnswlib/space_l2.h

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ namespace hnswlib {
2323

2424
// Favor using AVX512 if available.
2525
static float
26-
L2SqrSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
26+
L2SqrSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
2727
float *pVect1 = (float *) pVect1v;
2828
float *pVect2 = (float *) pVect2v;
2929
size_t qty = *((size_t *) qty_ptr);
@@ -52,12 +52,13 @@ namespace hnswlib {
5252

5353
return (res);
5454
}
55+
#endif
5556

56-
#elif defined(USE_AVX)
57+
#if defined(USE_AVX)
5758

5859
// Favor using AVX if available.
5960
static float
60-
L2SqrSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
61+
L2SqrSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
6162
float *pVect1 = (float *) pVect1v;
6263
float *pVect2 = (float *) pVect2v;
6364
size_t qty = *((size_t *) qty_ptr);
@@ -89,10 +90,12 @@ namespace hnswlib {
8990
return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];
9091
}
9192

92-
#elif defined(USE_SSE)
93+
#endif
94+
95+
#if defined(USE_SSE)
9396

9497
static float
95-
L2SqrSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
98+
L2SqrSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
9699
float *pVect1 = (float *) pVect1v;
97100
float *pVect2 = (float *) pVect2v;
98101
size_t qty = *((size_t *) qty_ptr);
@@ -141,6 +144,27 @@ namespace hnswlib {
141144
#endif
142145

143146
#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
147+
static float
148+
L2SqrSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
149+
DISTFUNC<float> simdfunc_;
150+
#if defined(USE_AVX512)
151+
if (AVX512Capable())
152+
simdfunc_ = L2SqrSIMD16ExtAVX512;
153+
else if (AVXCapable())
154+
simdfunc_ = L2SqrSIMD16ExtAVX;
155+
else
156+
simdfunc_ = L2SqrSIMD16ExtSSE;
157+
#elif defined(USE_AVX)
158+
if (AVXCapable())
159+
simdfunc_ = L2SqrSIMD16ExtAVX;
160+
else
161+
simdfunc_ = L2SqrSIMD16ExtSSE;
162+
#else
163+
simdfunc_ = L2SqrSIMD16ExtSSE;
164+
#endif
165+
return simdfunc_(pVect1v, pVect2v, qty_ptr);
166+
}
167+
144168
static float
145169
L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
146170
size_t qty = *((size_t *) qty_ptr);
@@ -156,7 +180,7 @@ namespace hnswlib {
156180
#endif
157181

158182

159-
#ifdef USE_SSE
183+
#if defined(USE_SSE)
160184
static float
161185
L2SqrSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
162186
float PORTABLE_ALIGN32 TmpRes[8];

0 commit comments

Comments
 (0)