Commit f5ed18b
CLBlast broadcast support for Llama 2 GQA
Implement broadcasting of src0 into the larger src1 for matrix multiplication. This has only been tested with Llama 2 inference; other cases may not be handled properly.
1 parent da04003 commit f5ed18b
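The change below replaces the per-(i02, i03) loops in the CLBlast mat-mul paths with loops over src1's batch dimensions. What follows is a minimal standalone sketch of that index mapping; the helper name broadcast_loop and the shape numbers in main are made up for illustration and are not part of this commit.

    // Illustrative sketch only (assumed helper, not code from this commit):
    // each src1 batch slice (i12, i13) reuses src0 slice (i12 / r2, i13 / r3),
    // and pi02/pi03 track the last uploaded src0 slice so it is not copied to
    // the device again while consecutive src1 slices broadcast against it.
    #include <stdint.h>
    #include <stdio.h>

    static void broadcast_loop(int64_t ne02, int64_t ne03, int64_t ne12, int64_t ne13) {
        const int64_t r2 = ne12 / ne02; // src1 slices per src0 slice along dim 2
        const int64_t r3 = ne13 / ne03; // src1 slices per src0 slice along dim 3

        int64_t pi02 = -1; // last src0 slice uploaded to the device
        int64_t pi03 = -1;

        for (int64_t i13 = 0; i13 < ne13; i13++) {
            const int64_t i03 = i13 / r3;
            for (int64_t i12 = 0; i12 < ne12; i12++) {
                const int64_t i02 = i12 / r2;
                if (i02 != pi02 || i03 != pi03) {
                    // real code: host-to-device copy of src0 slice (i03, i02)
                    printf("upload src0 slice (i03=%lld, i02=%lld)\n", (long long) i03, (long long) i02);
                    pi02 = i02;
                    pi03 = i03;
                }
                // real code: upload src1 slice (i13, i12), run the GEMM,
                // and write the result to dst slice (i13, i12)
            }
        }
    }

    int main(void) {
        // hypothetical GQA-like shape: 8 KV-head slices broadcast over 32 query-head slices
        broadcast_loop(8, 1, 32, 1);
        return 0;
    }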

File tree

ggml-opencl.cpp
ggml.c

2 files changed: +71 -28 lines

ggml-opencl.cpp

Lines changed: 71 additions & 23 deletions
@@ -1476,10 +1476,15 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr

     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];

     const int nb2 = dst->nb[2];
     const int nb3 = dst->nb[3];

+    const int64_t r2 = ne12 / ne02;
+    const int64_t r3 = ne13 / ne03;
+
     const float alpha = 1.0f;
     const float beta = 0.0f;
     const int x_ne = ne01 * ne00;
@@ -1498,13 +1503,24 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
     cl_mem d_Y = ggml_cl_pool_malloc(sizeof(float) * y_ne, &y_size);
     cl_mem d_D = ggml_cl_pool_malloc(sizeof(float) * d_ne, &d_size);

-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
+    int64_t pi02 = -1;
+    int64_t pi03 = -1;
+
+    for (int64_t i13 = 0; i13 < ne13; i13++) {
+        int64_t i03 = i13 / r3;
+
+        for (int64_t i12 = 0; i12 < ne12; i12++) {
+            int64_t i02 = i12 / r2;
+
             // copy data to device
             if (src0->backend != GGML_BACKEND_GPU) {
-                CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL));
+                if (i02 != pi02 || i03 != pi03) {
+                    CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL));
+                    pi02 = i02;
+                    pi03 = i03;
+                }
             }
-            CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i03, i02, NULL));
+            CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, NULL));

             CL_CHECK(clFinish(queue));

@@ -1525,7 +1541,7 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
             }

             // copy dst to host
-            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+            float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
             CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(float) * d_ne, d, 1, &ev_sgemm, NULL));
         }
     }
@@ -1547,6 +1563,8 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr

     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];

     const int nb10 = src1->nb[0];
     const int nb11 = src1->nb[1];
@@ -1556,6 +1574,9 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
     const int nb2 = dst->nb[2];
     const int nb3 = dst->nb[3];

+    const int64_t r2 = ne12 / ne02;
+    const int64_t r3 = ne13 / ne03;
+
     const ggml_fp16_t alpha = ggml_fp32_to_fp16(1.0f);
     const ggml_fp16_t beta = ggml_fp32_to_fp16(0.0f);
     const int x_ne = ne01 * ne00;
@@ -1577,32 +1598,43 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
     bool src1_cont_rows = nb10 == sizeof(float);
     bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float);

-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
+    int64_t pi02 = -1;
+    int64_t pi03 = -1;
+
+    for (int64_t i13 = 0; i13 < ne13; i13++) {
+        int64_t i03 = i13 / r3;
+
+        for (int64_t i12 = 0; i12 < ne12; i12++) {
+            int64_t i02 = i12 / r2;
+
             // copy src0 to device
             if (src0->backend != GGML_BACKEND_GPU) {
-                CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL));
+                if (i02 != pi02 || i03 != pi03) {
+                    CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL));
+                    pi02 = i02;
+                    pi03 = i03;
+                }
             }

             // convert src1 to fp16
             // TODO: use multiple threads
-            ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i03 * ne02 + i02);
-            char * src1i = (char *) src1->data + i03*nb13 + i02*nb12;
+            ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i13 * ne12 + i12);
+            char * src1i = (char *) src1->data + i13*nb13 + i12*nb12;
             if (src1_cont_rows) {
                 if (src1_cont_cols) {
                     ggml_fp32_to_fp16_row((float *) src1i, tmp, ne10*ne11);
                 }
                 else {
-                    for (int64_t i01 = 0; i01 < ne11; i01++) {
-                        ggml_fp32_to_fp16_row((float *) (src1i + i01*nb11), tmp + i01*ne10, ne10);
+                    for (int64_t i11 = 0; i11 < ne11; i11++) {
+                        ggml_fp32_to_fp16_row((float *) (src1i + i11*nb11), tmp + i11*ne10, ne10);
                     }
                 }
             }
             else {
-                for (int64_t i01 = 0; i01 < ne11; i01++) {
-                    for (int64_t i00 = 0; i00 < ne10; i00++) {
+                for (int64_t i11 = 0; i11 < ne11; i11++) {
+                    for (int64_t i10 = 0; i10 < ne10; i10++) {
                         // very slow due to no inlining
-                        tmp[i01*ne10 + i00] = ggml_fp32_to_fp16(*(float *) (src1i + i01*nb11 + i00*nb10));
+                        tmp[i11*ne10 + i10] = ggml_fp32_to_fp16(*(float *) (src1i + i11*nb11 + i10*nb10));
                     }
                 }
             }
@@ -1631,7 +1663,7 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
             // copy dst to host, then convert to float
             CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(ggml_fp16_t) * d_ne, tmp, 1, &ev_sgemm, NULL));

-            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+            float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);

             ggml_fp16_to_fp32_row(tmp, d, d_ne);
         }
@@ -1652,12 +1684,17 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *

     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];

     const int nb2 = dst->nb[2];
     const int nb3 = dst->nb[3];
     const ggml_type type = src0->type;
     const bool mul_mat_vec = ne11 == 1;

+    const int64_t r2 = ne12 / ne02;
+    const int64_t r3 = ne13 / ne03;
+
     const float alpha = 1.0f;
     const float beta = 0.0f;
     const int x_ne = ne01 * ne00;
@@ -1690,12 +1727,23 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
     size_t ev_idx = 0;
     std::vector<cl_event> events;

-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
+    int64_t pi02 = -1;
+    int64_t pi03 = -1;
+
+    for (int64_t i13 = 0; i13 < ne13; i13++) {
+        int64_t i03 = i13 / r3;
+
+        for (int64_t i12 = 0; i12 < ne12; i12++) {
+            int64_t i02 = i12 / r2;
+
             // copy src0 to device if necessary
             if (src0->backend == GGML_BACKEND_CPU) {
-                events.emplace_back();
-                CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Q, 0, src0, i03, i02, events.data() + ev_idx++));
+                if (i02 != pi02 || i03 != pi03) {
+                    events.emplace_back();
+                    CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Q, 0, src0, i03, i02, events.data() + ev_idx++));
+                    pi02 = i02;
+                    pi03 = i03;
+                }
             } else if (src0->backend == GGML_BACKEND_GPU) {
                 d_Q = (cl_mem) src0->extra;
             } else {
@@ -1704,7 +1752,7 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
             if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel
                 // copy src1 to device
                 events.emplace_back();
-                CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i03, i02, events.data() + ev_idx++));
+                CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, events.data() + ev_idx++));

                 // compute
                 const size_t global = ne01 * CL_DMMV_BLOCK_SIZE;
@@ -1725,7 +1773,7 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
                 CL_CHECK(clEnqueueNDRangeKernel(queue, *to_fp32_cl, 1, NULL, &global, local > 0 ? &local : NULL, events.size(), !events.empty() ? events.data() : NULL, NULL));

                 // copy src1 to device
-                CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i03, i02, NULL));
+                CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, NULL));

                 events.emplace_back();

@@ -1749,7 +1797,7 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
             }

             // copy dst to host
-            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+            float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
             CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(float) * d_ne, d, 1, &events[events.size() - 1], NULL));
             for (auto *event : events) {
                 clReleaseEvent(event);
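One detail in the hunks above: the dst pointer is now computed from the src1 batch indices, because under broadcasting the output has one slice per src1 slice rather than per src0 slice. A hypothetical helper (not part of the commit) capturing the same pointer arithmetic:

    #include <stddef.h>
    #include <stdint.h>

    // Hypothetical helper mirroring the dst-offset change in the diff above:
    // the byte strides nb2/nb3 are advanced by the src1 batch indices i12/i13,
    // since dst has as many batch slices as src1, not as src0.
    static float * dst_slice(void * dst_data, size_t nb2, size_t nb3,
                             int64_t i12, int64_t i13) {
        return (float *) ((char *) dst_data + (size_t) i12 * nb2 + (size_t) i13 * nb3);
    }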

ggml.c

Lines changed: 0 additions & 5 deletions
@@ -11265,11 +11265,6 @@ static void ggml_compute_forward_mul_mat(

 #if defined(GGML_USE_CLBLAST)
     if (ggml_cl_can_mul_mat(src0, src1, dst)) {
-        // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
-        // ref: https://github.com/ggerganov/ggml/pull/224
-        GGML_ASSERT(ne02 == ne12);
-        GGML_ASSERT(ne03 == ne13);
-
         if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
             ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
         }
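The removed assertions were the guard that previously kept broadcast shapes off the CLBlast path. The new loops rely on ggml's usual broadcast contract, namely that the src1 batch dimensions are whole multiples of the src0 batch dimensions; otherwise the integer divisions behind r2 and r3 would skip slices. A hypothetical predicate expressing that assumption (not code from this commit):

    #include <stdbool.h>
    #include <stdint.h>

    // Hypothetical check (illustrative only): the batch shapes the new CLBlast
    // path can handle, i.e. src0 batch dims evenly divide src1 batch dims.
    static bool can_broadcast_src0_into_src1(int64_t ne02, int64_t ne03,
                                             int64_t ne12, int64_t ne13) {
        return ne12 % ne02 == 0 && ne13 % ne03 == 0;
    }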
