@@ -107,6 +107,45 @@ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
107
107
inline float16x8_t mul (float16x8_t x, float16x8_t y) { return vmulq_f16 (x, y); }
108
108
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
109
109
110
+ // //////////////////////////////////////////////////////////////////////////////////////////////////
111
+ // VECTORIZED FUSED MULTIPLY ADD
112
+
113
+ /* *
114
+ * Computes a * b + c.
115
+ */
116
+ template <typename T, typename U>
117
+ inline U madd (T a, T b, U c) {
118
+ return add (mul (a, b), c);
119
+ }
120
+
121
+ #if defined(__FMA__)
122
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
123
+ template <>
124
+ inline __m256 madd (__m256 a, __m256 b, __m256 c) {
125
+ return _mm256_fmadd_ps (a, b, c);
126
+ }
127
+ #endif
128
+ #if defined(__AVX512F__)
129
+ template <>
130
+ inline __m512 madd (__m512 a, __m512 b, __m512 c) {
131
+ return _mm512_fmadd_ps (a, b, c);
132
+ }
133
+ #endif
134
+ #endif
135
+
136
+ #if defined(__ARM_FEATURE_FMA)
137
+ template <>
138
+ inline float32x4_t madd (float32x4_t a, float32x4_t b, float32x4_t c) {
139
+ return vfmaq_f32 (c, b, a);
140
+ }
141
+ #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
142
+ template <>
143
+ inline float16x8_t madd (float16x8_t a, float16x8_t b, float16x8_t c) {
144
+ return vfmaq_f16 (c, b, a);
145
+ }
146
+ #endif
147
+ #endif
148
+
110
149
// //////////////////////////////////////////////////////////////////////////////////////////////////
111
150
// VECTORIZED HORIZONTAL SUM
112
151
@@ -198,21 +237,6 @@ template <> inline __m512 load(const ggml_fp16_t *p) {
198
237
}
199
238
#endif // __AVX512F__
200
239
201
- // //////////////////////////////////////////////////////////////////////////////////////////////////
202
- // ABSTRACTIONS
203
-
204
- /* *
205
- * Computes a * b + c.
206
- *
207
- * This operation will become fused into a single arithmetic instruction
208
- * if the hardware has support for this feature, e.g. Intel Haswell+ (c.
209
- * 2013), AMD Bulldozer+ (c. 2011), etc.
210
- */
211
- template <typename T, typename U>
212
- inline U madd (T a, T b, U c) {
213
- return add (mul (a, b), c);
214
- }
215
-
216
240
// //////////////////////////////////////////////////////////////////////////////////////////////////
217
241
// FLOATING POINT MATRIX MULTIPLICATION
218
242
0 commit comments