 #if defined(GGML_USE_ACCELERATE)
 #   include <Accelerate/Accelerate.h>
-#elif defined(GGML_USE_BLAS)
-#   if defined(GGML_BLAS_USE_MKL)
-#       include <mkl.h>
-#   else
-#       include <cblas.h>
-#   endif
+#elif defined(GGML_BLAS_USE_MKL)
+#   include <mkl.h>
+#else
+#   include <cblas.h>
 #endif
 
 struct ggml_backend_blas_context {
@@ -21,7 +19,7 @@ struct ggml_backend_blas_context {
 
 // helper function to determine if it is better to use BLAS or not
 // for large matrices, BLAS is faster
-static bool ggml_compute_forward_mul_mat_use_blas(const struct ggml_tensor * dst) {
+static bool ggml_backend_blas_use_blas(const struct ggml_tensor * dst) {
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
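Note: the body of this helper is unchanged by the patch, so the hunk only shows the rename. As a rough, hypothetical sketch of the kind of check it performs (the threshold and exact conditions below are assumptions for illustration, not taken from this commit), it boils down to "use BLAS only for large, contiguous F32 inputs":

// hypothetical sketch of the size heuristic; the real cutoffs live in the
// unchanged function body that the diff elides
static bool use_blas_sketch(const struct ggml_tensor * dst) {
    const struct ggml_tensor * src0 = dst->src[0];
    const struct ggml_tensor * src1 = dst->src[1];

    const int64_t min_dim = 32; // assumed threshold: BLAS call overhead dominates below this

    return ggml_is_contiguous(src0) &&
           ggml_is_contiguous(src1) &&
           src1->type == GGML_TYPE_F32 &&
           dst->ne[0] >= min_dim && dst->ne[1] >= min_dim && src1->ne[0] >= min_dim;
}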
@@ -72,11 +70,8 @@ static void ggml_backend_blas_mul_mat(struct ggml_backend_blas_context * ctx, st
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;
 
-    // nb01 >= nb00 - src0 is not transposed
-    // compute by src0 rows
-
     const int64_t ne_plane      = ne01*ne00;
-    const size_t  desired_wsize = type == GGML_TYPE_F32 ? 0 : ne13*ne12*ne_plane*sizeof(float);
+    const size_t  desired_wsize = type == GGML_TYPE_F32 ? 0 : ne03*ne02*ne_plane*sizeof(float);
 
     if (ctx->work_size < desired_wsize) {
         free(ctx->work_data);
@@ -87,21 +82,19 @@ static void ggml_backend_blas_mul_mat(struct ggml_backend_blas_context * ctx, st
     void * wdata = ctx->work_data;
 
     // convert src0 to float
-    if (true) {
-        if (type != GGML_TYPE_F32) {
-            ggml_to_float_t const to_float = type_traits.to_float;
+    if (type != GGML_TYPE_F32) {
+        ggml_to_float_t const to_float = type_traits.to_float;
 
-            for (int64_t i03 = 0; i03 < ne03; i03++) {
-                for (int64_t i02 = 0; i02 < ne02; i02++) {
-                    const void  *          x = (char *) src0->data + i02*nb02 + i03*nb03;
-                    float       * const wplane = (float *) wdata + i03*ne12*ne_plane + i02*ne_plane;
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                const void  *          x = (char *) src0->data + i02*nb02 + i03*nb03;
+                float       * const wplane = (float *) wdata + i03*ne12*ne_plane + i02*ne_plane;
 
 #ifdef GGML_USE_OPENMP
 #pragma omp parallel for num_threads(ctx->n_threads)
 #endif
-                    for (int64_t i01 = 0; i01 < ne01; i01++) {
-                        to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
-                    }
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
                 }
             }
         }
@@ -129,6 +122,70 @@ static void ggml_backend_blas_mul_mat(struct ggml_backend_blas_context * ctx, st
     }
 }
 
+static void ggml_backend_blas_out_prod(struct ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(ne0  == ne00);
+    GGML_ASSERT(ne1  == ne10);
+    GGML_ASSERT(ne2  == ne02);
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne3  == ne13);
+    GGML_ASSERT(ne03 == ne13);
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    // GGML_ASSERT(nb0 <= nb1);
+    // GGML_ASSERT(nb1 <= nb2);
+    // GGML_ASSERT(nb2 <= nb3);
+
+    // nb01 >= nb00 - src0 is not transposed
+    // compute by src0 rows
+
+    // Arguments to ggml_compute_forward_out_prod (expressed as major,minor)
+    // src0: (k,n)
+    // src1: (k,m)
+    // dst:  (m,n)
+    //
+    // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f)
+    // Also expressed as (major,minor)
+    // a: (m,k): so src1 transposed
+    // b: (k,n): so src0
+    // c: (m,n)
+    //
+    // However, if ggml_is_transposed(src1) is true, then
+    // src1->data already contains a transposed version, so sgemm mustn't
+    // transpose it further.
+
+    int n = src0->ne[0];
+    int k = src0->ne[1];
+    int m = src1->ne[0];
+
+    int transposeA;
+    int lda;
+
+    if (!ggml_is_transposed(src1)) {
+        transposeA = CblasTrans;
+        lda = m;
+    } else {
+        transposeA = CblasNoTrans;
+        lda = k;
+    }
+
+    float * a = (float *) ((char *) src1->data);
+    float * b = (float *) ((char *) src0->data);
+    float * c = (float *) ((char *) dst->data);
+
+    cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n);
+
+    GGML_UNUSED(ctx);
+}
+
 // backend interface
 
 GGML_CALL static const char * ggml_backend_blas_name(ggml_backend_t backend) {
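The comment block in ggml_backend_blas_out_prod maps ggml's out_prod operands onto cblas_sgemm arguments. The standalone sketch below (not part of the patch; it only assumes a CBLAS implementation such as OpenBLAS is linked) reproduces that mapping on tiny matrices and checks it against a naive triple loop, which is the easiest way to convince yourself that m, n, k and lda are passed correctly for the non-transposed src1 case:

// standalone verification of the sgemm argument mapping used above
#include <cblas.h>
#include <stdio.h>

int main(void) {
    enum { k = 2, n = 3, m = 4 };
    float src0[k*n], src1[k*m], dst[m*n], ref[m*n];

    // fill with arbitrary values
    for (int i = 0; i < k*n; i++) src0[i] = (float) (i + 1);
    for (int i = 0; i < k*m; i++) src1[i] = (float) (i + 1) * 0.5f;

    // same call shape as ggml_backend_blas_out_prod for non-transposed src1:
    // C(m x n) = A^T(m x k) * B(k x n), with A = src1 stored row-major as (k x m)
    cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans,
                m, n, k, 1.0f, src1, m, src0, n, 0.0f, dst, n);

    // naive reference: dst[i][j] = sum_l src1[l][i] * src0[l][j]
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < n; j++) {
            float s = 0.0f;
            for (int l = 0; l < k; l++) {
                s += src1[l*m + i] * src0[l*n + j];
            }
            ref[i*n + j] = s;
        }
    }

    for (int i = 0; i < m*n; i++) {
        printf("%g %g\n", dst[i], ref[i]); // the two columns should match
    }
    return 0;
}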
@@ -138,6 +195,9 @@ GGML_CALL static const char * ggml_backend_blas_name(ggml_backend_t backend) {
 }
 
 GGML_CALL static void ggml_backend_blas_free(ggml_backend_t backend) {
+    struct ggml_backend_blas_context * ctx = (struct ggml_backend_blas_context *)backend->context;
+    free(ctx->work_data);
+    free(ctx);
     free(backend);
 }
 
@@ -158,8 +218,9 @@ GGML_CALL static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t
                 ggml_backend_blas_mul_mat(ctx, node);
                 break;
 
-            // TODO
-            //case GGML_OP_OUT_PROD:
+            case GGML_OP_OUT_PROD:
+                ggml_backend_blas_out_prod(ctx, node);
+                break;
 
             case GGML_OP_NONE:
             case GGML_OP_RESHAPE:
@@ -180,7 +241,16 @@ GGML_CALL static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t
 }
 
 GGML_CALL static bool ggml_backend_blas_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
-    return op->op == GGML_OP_MUL_MAT && ggml_compute_forward_mul_mat_use_blas(op);
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * src1 = op->src[1];
+
+    return (op->op == GGML_OP_MUL_MAT  && ggml_backend_blas_use_blas(op)) ||
+           (op->op == GGML_OP_OUT_PROD && op->src[0]->type == GGML_TYPE_F32 &&
+                                          op->src[1]->type == GGML_TYPE_F32 &&
+                                          ggml_is_matrix(src0) &&
+                                          ggml_is_matrix(src1) &&
+                                          ggml_is_contiguous(src0) &&
+                                          (ggml_is_contiguous(src1) || ggml_is_transposed(src1)));
 
     GGML_UNUSED(backend);
 }
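With supports_op extended, the scheduler can now route OUT_PROD nodes over two F32 matrices (contiguous, or a transposed src1) to this backend. The following is a hypothetical end-to-end usage sketch; it only uses public ggml APIs that exist independently of this patch, and the tensor shapes and memory sizes are illustrative assumptions, not values from the commit:

// hypothetical example: compute one OUT_PROD node on the BLAS backend
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-blas.h"

static void run_out_prod_example(void) {
    struct ggml_init_params params = {
        /* .mem_size   = */ ggml_tensor_overhead()*8 + ggml_graph_overhead(),
        /* .mem_buffer = */ NULL,
        /* .no_alloc   = */ true, // tensor data lives in a backend buffer
    };
    struct ggml_context * ctx = ggml_init(params);

    // ggml_out_prod(a, b): a is (n, k), b is (m, k), result is (n, m)
    struct ggml_tensor * a   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 16);
    struct ggml_tensor * b   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 32, 16);
    struct ggml_tensor * out = ggml_out_prod(ctx, a, b); // (64, 32)

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);

    ggml_backend_t backend = ggml_backend_blas_init();
    ggml_gallocr_t galloc  = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
    ggml_gallocr_alloc_graph(galloc, gf);

    // ... fill a and b here with ggml_backend_tensor_set(...) ...

    ggml_backend_graph_compute(backend, gf); // dispatches to ggml_backend_blas_out_prod

    ggml_gallocr_free(galloc);
    ggml_backend_free(backend);
    ggml_free(ctx);
}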
@@ -229,9 +299,9 @@ ggml_backend_t ggml_backend_blas_init(void) {
         return NULL;
     }
 
-    ctx->n_threads = GGML_DEFAULT_N_THREADS;
-    ctx->work_data = NULL;
-    ctx->work_size = 0;
+    ctx->n_threads = GGML_DEFAULT_N_THREADS;
+    ctx->work_data = NULL;
+    ctx->work_size = 0;
 
     *backend = (struct ggml_backend) {
         /* .guid      = */ ggml_backend_blas_guid(),