Commit 77f88e3

add support for out_prod
1 parent b88957e commit 77f88e3

File tree: 2 files changed (+98, -28 lines)

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -92,8 +92,8 @@ endif()
 # 3rd party libs
 option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON)
 option(LLAMA_BLAS "llama: use BLAS" OFF)
-option(LLAMA_LLAMAFILE "llama: use llamafile SGEMM" ${LLAMA_LLAMAFILE_DEFAULT})
 set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
+option(LLAMA_LLAMAFILE "llama: use llamafile SGEMM" ${LLAMA_LLAMAFILE_DEFAULT})
 option(LLAMA_CUDA "llama: use CUDA" OFF)
 option(LLAMA_CUBLAS "llama: use CUDA (deprecated, use LLAMA_CUDA)" OFF)
 option(LLAMA_CUDA_FORCE_DMMV "llama: use dmmv instead of mmvq CUDA kernels" OFF)

ggml-blas.c

Lines changed: 97 additions & 27 deletions
@@ -5,12 +5,10 @@

 #if defined(GGML_USE_ACCELERATE)
 #   include <Accelerate/Accelerate.h>
-#elif defined(GGML_USE_BLAS)
-#   if defined(GGML_BLAS_USE_MKL)
-#       include <mkl.h>
-#   else
-#       include <cblas.h>
-#   endif
+#elif defined(GGML_BLAS_USE_MKL)
+#   include <mkl.h>
+#else
+#   include <cblas.h>
 #endif

 struct ggml_backend_blas_context {
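Note on the include change: the old code only reached the MKL/cblas choice behind an extra GGML_USE_BLAS guard. Since this translation unit is presumably only compiled when some BLAS backend is enabled, that guard was redundant; the chain now simply prefers Accelerate, then MKL, then a generic cblas.h.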
@@ -21,7 +19,7 @@ struct ggml_backend_blas_context {

 // helper function to determine if it is better to use BLAS or not
 // for large matrices, BLAS is faster
-static bool ggml_compute_forward_mul_mat_use_blas(const struct ggml_tensor * dst) {
+static bool ggml_backend_blas_use_blas(const struct ggml_tensor * dst) {
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];

@@ -72,11 +70,8 @@ static void ggml_backend_blas_mul_mat(struct ggml_backend_blas_context * ctx, st
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;

-    // nb01 >= nb00 - src0 is not transposed
-    // compute by src0 rows
-
     const int64_t ne_plane = ne01*ne00;
-    const size_t desired_wsize = type == GGML_TYPE_F32 ? 0 : ne13*ne12*ne_plane*sizeof(float);
+    const size_t desired_wsize = type == GGML_TYPE_F32 ? 0 : ne03*ne02*ne_plane*sizeof(float);

     if (ctx->work_size < desired_wsize) {
         free(ctx->work_data);
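The workspace-size fix above matters because the scratch buffer holds the float conversion of src0, so it must be sized from src0's batch dimensions (ne02, ne03), not src1's (ne12, ne13). As a worked example (shape chosen purely for illustration): a quantized src0 with ne = [4096, 4096, 8, 1] needs 1*8*4096*4096*4 bytes = 512 MiB of float workspace, regardless of src1's batch shape.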
@@ -87,21 +82,19 @@ static void ggml_backend_blas_mul_mat(struct ggml_backend_blas_context * ctx, st
     void * wdata = ctx->work_data;

     // convert src0 to float
-    if (true) {
-        if (type != GGML_TYPE_F32) {
-            ggml_to_float_t const to_float = type_traits.to_float;
+    if (type != GGML_TYPE_F32) {
+        ggml_to_float_t const to_float = type_traits.to_float;

-            for (int64_t i03 = 0; i03 < ne03; i03++) {
-                for (int64_t i02 = 0; i02 < ne02; i02++) {
-                    const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
-                    float * const wplane = (float *) wdata + i03*ne12*ne_plane + i02*ne_plane;
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
+                float * const wplane = (float *) wdata + i03*ne12*ne_plane + i02*ne_plane;

 #ifdef GGML_USE_OPENMP
 #pragma omp parallel for num_threads(ctx->n_threads)
 #endif
-                    for (int64_t i01 = 0; i01 < ne01; i01++) {
-                        to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
-                    }
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
                 }
             }
         }
@@ -129,6 +122,70 @@ static void ggml_backend_blas_mul_mat(struct ggml_backend_blas_context * ctx, st
     }
 }

+static void ggml_backend_blas_out_prod(struct ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne10);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne3 == ne13);
+    GGML_ASSERT(ne03 == ne13);
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    // GGML_ASSERT(nb0 <= nb1);
+    // GGML_ASSERT(nb1 <= nb2);
+    // GGML_ASSERT(nb2 <= nb3);
+
+    // nb01 >= nb00 - src0 is not transposed
+    // compute by src0 rows
+
+    // Arguments to ggml_compute_forward_out_prod (expressed as major,minor)
+    // src0: (k,n)
+    // src1: (k,m)
+    // dst:  (m,n)
+    //
+    // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f)
+    // Also expressed as (major,minor)
+    // a: (m,k): so src1 transposed
+    // b: (k,n): so src0
+    // c: (m,n)
+    //
+    // However, if ggml_is_transposed(src1) is true, then
+    // src1->data already contains a transposed version, so sgemm mustn't
+    // transpose it further.
+
+    int n = src0->ne[0];
+    int k = src0->ne[1];
+    int m = src1->ne[0];
+
+    int transposeA;
+    int lda;
+
+    if (!ggml_is_transposed(src1)) {
+        transposeA = CblasTrans;
+        lda = m;
+    } else {
+        transposeA = CblasNoTrans;
+        lda = k;
+    }
+
+    float * a = (float *) ((char *) src1->data);
+    float * b = (float *) ((char *) src0->data);
+    float * c = (float *) ((char *) dst->data);
+
+    cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n);
+
+    GGML_UNUSED(ctx);
+}
+
 // backend interface

 GGML_CALL static const char * ggml_backend_blas_name(ggml_backend_t backend) {
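To make the shape bookkeeping above concrete, here is a minimal standalone sketch, not part of the commit, assuming only a linked CBLAS implementation (e.g. OpenBLAS). It exercises the same mapping as the non-transposed branch: dst(m,n) = src1(k,m)^T * src0(k,n), all row-major, with transA = CblasTrans and lda = m.

#include <cblas.h>
#include <stdio.h>

int main(void) {
    enum { k = 2, n = 3, m = 2 };

    // src0: (k,n) = [[1,2,3],[4,5,6]], row-major
    float src0[k*n] = {1, 2, 3,  4, 5, 6};
    // src1: (k,m) = identity, so src1^T is also the 2x2 identity
    float src1[k*m] = {1, 0,  0, 1};
    float dst[m*n]  = {0};

    // a = src1 (transposed by sgemm, lda = m), b = src0 (ldb = n), c = dst (ldc = n)
    cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans,
                m, n, k, 1.0f, src1, m, src0, n, 0.0f, dst, n);

    for (int i = 0; i < m; i++) {
        for (int j = 0; j < n; j++) {
            printf("%g ", dst[i*n + j]);
        }
        printf("\n");
    }
    return 0;
}

With src1 set to the identity, dst should come back equal to src0 (rows 1 2 3 and 4 5 6), which is a quick sanity check of the transA/lda/ldb choices.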
@@ -138,6 +195,9 @@ GGML_CALL static const char * ggml_backend_blas_name(ggml_backend_t backend) {
 }

 GGML_CALL static void ggml_backend_blas_free(ggml_backend_t backend) {
+    struct ggml_backend_blas_context * ctx = (struct ggml_backend_blas_context *)backend->context;
+    free(ctx->work_data);
+    free(ctx);
     free(backend);
 }
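Worth noting: before this change, ggml_backend_blas_free released only the backend struct itself, so the context and its conversion work buffer were leaked on teardown; the added lines close that leak.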

@@ -158,8 +218,9 @@ GGML_CALL static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t
                 ggml_backend_blas_mul_mat(ctx, node);
                 break;

-            // TODO
-            //case GGML_OP_OUT_PROD:
+            case GGML_OP_OUT_PROD:
+                ggml_backend_blas_out_prod(ctx, node);
+                break;

             case GGML_OP_NONE:
             case GGML_OP_RESHAPE:
@@ -180,7 +241,16 @@ GGML_CALL static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t
 }

 GGML_CALL static bool ggml_backend_blas_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
-    return op->op == GGML_OP_MUL_MAT && ggml_compute_forward_mul_mat_use_blas(op);
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * src1 = op->src[1];
+
+    return (op->op == GGML_OP_MUL_MAT && ggml_backend_blas_use_blas(op)) ||
+           (op->op == GGML_OP_OUT_PROD && op->src[0]->type == GGML_TYPE_F32 &&
+            op->src[1]->type == GGML_TYPE_F32 &&
+            ggml_is_matrix(src0) &&
+            ggml_is_matrix(src1) &&
+            ggml_is_contiguous(src0) &&
+            (ggml_is_contiguous(src1) || ggml_is_transposed(src1)));

     GGML_UNUSED(backend);
 }
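As a rough illustration of the new gate (a hypothetical sketch, not code from the commit; shapes are arbitrary):

#include "ggml.h"

// Hypothetical sketch: which out_prod nodes the new gate would claim.
int main(void) {
    struct ggml_init_params params = { .mem_size = 16*1024*1024, .mem_buffer = NULL, .no_alloc = false };
    struct ggml_context * ctx = ggml_init(params);

    // contiguous 2-D f32 pair: passes every condition in supports_op
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 32); // src0
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 16, 32); // src1
    struct ggml_tensor * d = ggml_out_prod(ctx, a, b); // dst shape: [64, 16]

    // batched src0: ggml_is_matrix() is false, so an out_prod built from it
    // would be rejected by the gate and fall back to another backend
    struct ggml_tensor * a3 = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 64, 32, 4);

    (void) d; (void) a3;
    ggml_free(ctx);
    return 0;
}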
@@ -229,9 +299,9 @@ ggml_backend_t ggml_backend_blas_init(void) {
         return NULL;
     }

-    ctx->n_threads = GGML_DEFAULT_N_THREADS;
-    ctx->work_data = NULL;
-    ctx->work_size = 0;
+    ctx->n_threads  = GGML_DEFAULT_N_THREADS;
+    ctx->work_data  = NULL;
+    ctx->work_size  = 0;

     *backend = (struct ggml_backend) {
         /* .guid = */ ggml_backend_blas_guid(),
