Skip to content

Commit 6b220dc

Browse files
committed
Help clang produce fma instructions
1 parent 9d4d14c commit 6b220dc

File tree

1 file changed

+39
-15
lines changed

1 file changed

+39
-15
lines changed

sgemm.cpp

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,45 @@ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
107107
inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
108108
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
109109

110+
////////////////////////////////////////////////////////////////////////////////////////////////////
111+
// VECTORIZED FUSED MULTIPLY ADD
112+
113+
/**
114+
* Computes a * b + c.
115+
*/
116+
template <typename T, typename U>
117+
inline U madd(T a, T b, U c) {
118+
return add(mul(a, b), c);
119+
}
120+
121+
#if defined(__FMA__)
122+
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
123+
template <>
124+
inline __m256 madd(__m256 a, __m256 b, __m256 c) {
125+
return _mm256_fmadd_ps(a, b, c);
126+
}
127+
#endif
128+
#if defined(__AVX512F__)
129+
template <>
130+
inline __m512 madd(__m512 a, __m512 b, __m512 c) {
131+
return _mm512_fmadd_ps(a, b, c);
132+
}
133+
#endif
134+
#endif
135+
136+
#if defined(__ARM_FEATURE_FMA)
137+
template <>
138+
inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
139+
return vfmaq_f32(c, b, a);
140+
}
141+
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
142+
template <>
143+
inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
144+
return vfmaq_f16(c, b, a);
145+
}
146+
#endif
147+
#endif
148+
110149
////////////////////////////////////////////////////////////////////////////////////////////////////
111150
// VECTORIZED HORIZONTAL SUM
112151

@@ -198,21 +237,6 @@ template <> inline __m512 load(const ggml_fp16_t *p) {
198237
}
199238
#endif // __AVX512F__
200239

201-
////////////////////////////////////////////////////////////////////////////////////////////////////
202-
// ABSTRACTIONS
203-
204-
/**
205-
* Computes a * b + c.
206-
*
207-
* This operation will become fused into a single arithmetic instruction
208-
* if the hardware has support for this feature, e.g. Intel Haswell+ (c.
209-
* 2013), AMD Bulldozer+ (c. 2011), etc.
210-
*/
211-
template <typename T, typename U>
212-
inline U madd(T a, T b, U c) {
213-
return add(mul(a, b), c);
214-
}
215-
216240
////////////////////////////////////////////////////////////////////////////////////////////////////
217241
// FLOATING POINT MATRIX MULTIPLICATION
218242

0 commit comments

Comments
 (0)