@@ -18,7 +18,7 @@ namespace hnswlib {
1818
1919// Favor using AVX if available.
2020 static float
21- InnerProductSIMD4Ext (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
21+ InnerProductSIMD4ExtAVX (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
2222 float PORTABLE_ALIGN32 TmpRes[8 ];
2323 float *pVect1 = (float *) pVect1v;
2424 float *pVect2 = (float *) pVect2v;
@@ -64,10 +64,12 @@ namespace hnswlib {
6464 return 1 .0f - sum;
6565}
6666
67- #elif defined(USE_SSE)
67+ #endif
68+
69+ #if defined(USE_SSE)
6870
6971 static float
70- InnerProductSIMD4Ext (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
72+ InnerProductSIMD4ExtSSE (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
7173 float PORTABLE_ALIGN32 TmpRes[8 ];
7274 float *pVect1 = (float *) pVect1v;
7375 float *pVect2 = (float *) pVect2v;
@@ -128,7 +130,7 @@ namespace hnswlib {
128130#if defined(USE_AVX512)
129131
130132 static float
131- InnerProductSIMD16Ext (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
133+ InnerProductSIMD16ExtAVX512 (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
132134 float PORTABLE_ALIGN64 TmpRes[16 ];
133135 float *pVect1 = (float *) pVect1v;
134136 float *pVect2 = (float *) pVect2v;
@@ -157,10 +159,12 @@ namespace hnswlib {
157159 return 1 .0f - sum;
158160 }
159161
160- #elif defined(USE_AVX)
162+ #endif
163+
164+ #if defined(USE_AVX)
161165
162166 static float
163- InnerProductSIMD16Ext (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
167+ InnerProductSIMD16ExtAVX (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
164168 float PORTABLE_ALIGN32 TmpRes[8 ];
165169 float *pVect1 = (float *) pVect1v;
166170 float *pVect2 = (float *) pVect2v;
@@ -195,10 +199,12 @@ namespace hnswlib {
195199 return 1 .0f - sum;
196200 }
197201
198- #elif defined(USE_SSE)
202+ #endif
203+
204+ #if defined(USE_SSE)
199205
200206 static float
201- InnerProductSIMD16Ext (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
207+ InnerProductSIMD16ExtSSE (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
202208 float PORTABLE_ALIGN32 TmpRes[8 ];
203209 float *pVect1 = (float *) pVect1v;
204210 float *pVect2 = (float *) pVect2v;
@@ -245,6 +251,41 @@ namespace hnswlib {
245251#endif
246252
247253#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
254+ static float
255+ InnerProductSIMD4Ext (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
256+ DISTFUNC<float > simdfunc_;
257+ #if defined(USE_AVX)
258+ if (AVXCapable ())
259+ simdfunc_ = InnerProductSIMD4ExtAVX;
260+ else
261+ simdfunc_ = InnerProductSIMD4ExtSSE;
262+ #else
263+ simdfunc_ = InnerProductSIMD4ExtSSE;
264+ #endif
265+ return simdfunc_ (pVect1v, pVect2v, qty_ptr);
266+ }
267+
268+ static float
269+ InnerProductSIMD16Ext (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
270+ DISTFUNC<float > simdfunc_;
271+ #if defined(USE_AVX512)
272+ if (AVX512Capable ())
273+ simdfunc_ = InnerProductSIMD16ExtAVX512;
274+ else if (AVXCapable ())
275+ simdfunc_ = InnerProductSIMD16ExtAVX;
276+ else
277+ simdfunc_ = InnerProductSIMD16ExtSSE;
278+ #elif defined(USE_AVX)
279+ if (AVXCapable ())
280+ simdfunc_ = InnerProductSIMD16ExtAVX;
281+ else
282+ simdfunc_ = InnerProductSIMD16ExtSSE;
283+ #else
284+ simdfunc_ = InnerProductSIMD16ExtSSE;
285+ #endif
286+ return simdfunc_ (pVect1v, pVect2v, qty_ptr);
287+ }
288+
248289 static float
249290 InnerProductSIMD16ExtResiduals (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
250291 size_t qty = *((size_t *) qty_ptr);
0 commit comments