Skip to content

Commit 0548a41

Browse files
authored
ggml : generalize GGML_OP_CONCAT (#7563)
* ggml : generalize GGML_OP_CONCAT (WIP) ggml-ci * tests : add dim != 2 tests * metal : generalize concat kernel * tests : naming * cuda : generalize concat kernel ggml-ci * sycl : add warning and assert * ggml : fix op params handling * metal : bugfix kernel ggml-ci * ggml : reimplement CPU and Metal * cuda : add asserts ggml-ci * ggml : fix ptrs ggml-ci
1 parent 9335b96 commit 0548a41

File tree

7 files changed

+168
-57
lines changed

7 files changed

+168
-57
lines changed

ggml-cuda/concat.cu

Lines changed: 87 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,68 @@
11
#include "concat.cuh"
22

3-
static __global__ void concat_f32(const float * x,const float * y, float * dst, const int ne0, const int ne02) {
3+
static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) {
44
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
55
if (nidx >= ne0) {
66
return;
77
}
8-
// operation
8+
9+
int offset_dst =
10+
nidx +
11+
blockIdx.y * ne0 +
12+
blockIdx.z * ne0 * gridDim.y;
13+
14+
if (nidx < ne00) { // src0
15+
int offset_src =
16+
nidx +
17+
blockIdx.y * ne00 +
18+
blockIdx.z * ne00 * gridDim.y;
19+
dst[offset_dst] = x[offset_src];
20+
} else {
21+
int offset_src =
22+
(nidx - ne00) +
23+
blockIdx.y * (ne0 - ne00) +
24+
blockIdx.z * (ne0 - ne00) * gridDim.y;
25+
dst[offset_dst] = y[offset_src];
26+
}
27+
}
28+
29+
static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int ne0, const int ne01) {
30+
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
31+
if (nidx >= ne0) {
32+
return;
33+
}
34+
35+
int offset_dst =
36+
nidx +
37+
blockIdx.y * ne0 +
38+
blockIdx.z * ne0 * gridDim.y;
39+
40+
if (blockIdx.y < ne01) { // src0
41+
int offset_src =
42+
nidx +
43+
blockIdx.y * ne0 +
44+
blockIdx.z * ne0 * ne01;
45+
dst[offset_dst] = x[offset_src];
46+
} else {
47+
int offset_src =
48+
nidx +
49+
(blockIdx.y - ne01) * ne0 +
50+
blockIdx.z * ne0 * (gridDim.y - ne01);
51+
dst[offset_dst] = y[offset_src];
52+
}
53+
}
54+
55+
static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int ne0, const int ne02) {
56+
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
57+
if (nidx >= ne0) {
58+
return;
59+
}
60+
961
int offset_dst =
1062
nidx +
1163
blockIdx.y * ne0 +
1264
blockIdx.z * ne0 * gridDim.y;
65+
1366
if (blockIdx.z < ne02) { // src0
1467
int offset_src =
1568
nidx +
@@ -25,25 +78,53 @@ static __global__ void concat_f32(const float * x,const float * y, float * dst,
2578
}
2679
}
2780

28-
static void concat_f32_cuda(const float * x, const float * y, float * dst, const int ne0, int ne1, int ne2, int ne02, cudaStream_t stream) {
81+
static void concat_f32_cuda(const float * x, const float * y, float * dst, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, int dim, cudaStream_t stream) {
2982
int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;
3083
dim3 gridDim(num_blocks, ne1, ne2);
31-
concat_f32<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
84+
if (dim == 0) {
85+
concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne00);
86+
return;
87+
}
88+
if (dim == 1) {
89+
concat_f32_dim1<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne01);
90+
return;
91+
}
92+
concat_f32_dim2<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
3293
}
3394

3495
void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
3596
const ggml_tensor * src0 = dst->src[0];
3697
const ggml_tensor * src1 = dst->src[1];
98+
3799
const float * src0_d = (const float *)src0->data;
38100
const float * src1_d = (const float *)src1->data;
101+
39102
float * dst_d = (float *)dst->data;
40103
cudaStream_t stream = ctx.stream();
41104

105+
const int32_t dim = ((int32_t *) dst->op_params)[0];
106+
107+
GGML_ASSERT(ggml_is_contiguous(src0));
108+
GGML_ASSERT(ggml_is_contiguous(src1));
109+
42110
GGML_ASSERT(src0->type == GGML_TYPE_F32);
43111
GGML_ASSERT(src1->type == GGML_TYPE_F32);
44112
GGML_ASSERT(dst->type == GGML_TYPE_F32);
45113

46-
for (int i3 = 0; i3 < dst->ne[3]; i3++) {
47-
concat_f32_cuda(src0_d + i3 * (src0->nb[3] / 4), src1_d + i3 * (src1->nb[3] / 4), dst_d + i3 * (dst->nb[3] / 4), dst->ne[0], dst->ne[1], dst->ne[2], src0->ne[2], stream);
114+
if (dim != 3) {
115+
for (int i3 = 0; i3 < dst->ne[3]; i3++) {
116+
concat_f32_cuda(
117+
src0_d + i3 * (src0->nb[3] / 4),
118+
src1_d + i3 * (src1->nb[3] / 4),
119+
dst_d + i3 * ( dst->nb[3] / 4),
120+
src0->ne[0], src0->ne[1], src0->ne[2],
121+
dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
122+
}
123+
} else {
124+
const size_t size0 = ggml_nbytes(src0);
125+
const size_t size1 = ggml_nbytes(src1);
126+
127+
CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
128+
CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
48129
}
49130
}

ggml-metal.m

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -990,6 +990,8 @@ static enum ggml_status ggml_metal_graph_compute(
990990
{
991991
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
992992

993+
const int32_t dim = ((int32_t *) dst->op_params)[0];
994+
993995
[encoder setComputePipelineState:pipeline];
994996
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
995997
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -1018,6 +1020,7 @@ static enum ggml_status ggml_metal_graph_compute(
10181020
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
10191021
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
10201022
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
1023+
[encoder setBytes:&dim length:sizeof(dim) atIndex:27];
10211024

10221025
const int nth = MIN(1024, ne0);
10231026

ggml-metal.metal

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3366,31 +3366,30 @@ kernel void kernel_concat(
33663366
constant uint64_t & nb1,
33673367
constant uint64_t & nb2,
33683368
constant uint64_t & nb3,
3369+
constant int32_t & dim,
33693370
uint3 tgpig[[threadgroup_position_in_grid]],
33703371
uint3 tpitg[[thread_position_in_threadgroup]],
33713372
uint3 ntg[[threads_per_threadgroup]]) {
33723373

3373-
const int64_t i03 = tgpig.z;
3374-
const int64_t i02 = tgpig.y;
3375-
const int64_t i01 = tgpig.x;
3374+
const int64_t i3 = tgpig.z;
3375+
const int64_t i2 = tgpig.y;
3376+
const int64_t i1 = tgpig.x;
33763377

3377-
const int64_t i13 = i03 % ne13;
3378-
const int64_t i12 = i02 % ne12;
3379-
const int64_t i11 = i01 % ne11;
3378+
int64_t o[4] = {0, 0, 0, 0};
3379+
o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
33803380

3381-
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
3382-
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
3383-
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
3381+
device const float * x;
33843382

33853383
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
3386-
if (i02 < ne02) {
3387-
((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
3388-
src0_ptr += ntg.x*nb00;
3384+
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
3385+
x = (device const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00);
33893386
} else {
3390-
((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
3391-
src1_ptr += ntg.x*nb10;
3387+
x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
33923388
}
3393-
dst_ptr += ntg.x*nb0;
3389+
3390+
device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
3391+
3392+
*y = *x;
33943393
}
33953394
}
33963395

ggml-sycl.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13512,6 +13512,10 @@ inline void ggml_sycl_op_concat(const ggml_tensor *src0,
1351213512
const float *src0_dd, const float *src1_dd,
1351313513
float *dst_dd,
1351413514
const dpct::queue_ptr &main_stream) {
13515+
#pragma message("TODO: generalize concat kernel for dim != 2")
13516+
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7563")
13517+
int dim = dst->op_params[0];
13518+
GGML_ASSERT(dim != 2);
1351513519

1351613520
GGML_ASSERT(src0->type == GGML_TYPE_F32);
1351713521
GGML_ASSERT(src1->type == GGML_TYPE_F32);

ggml.c

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4882,18 +4882,31 @@ struct ggml_tensor * ggml_repeat_back(
48824882
// ggml_concat
48834883

48844884
struct ggml_tensor * ggml_concat(
4885-
struct ggml_context* ctx,
4886-
struct ggml_tensor* a,
4887-
struct ggml_tensor* b) {
4888-
GGML_ASSERT(a->ne[0] == b->ne[0] && a->ne[1] == b->ne[1] && a->ne[3] == b->ne[3]);
4885+
struct ggml_context * ctx,
4886+
struct ggml_tensor * a,
4887+
struct ggml_tensor * b,
4888+
int dim) {
4889+
GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
4890+
4891+
int64_t ne[GGML_MAX_DIMS];
4892+
for (int d = 0; d < GGML_MAX_DIMS; ++d) {
4893+
if (d == dim) {
4894+
ne[d] = a->ne[d] + b->ne[d];
4895+
continue;
4896+
}
4897+
GGML_ASSERT(a->ne[d] == b->ne[d]);
4898+
ne[d] = a->ne[d];
4899+
}
48894900

48904901
bool is_node = false;
48914902

48924903
if (a->grad || b->grad) {
48934904
is_node = true;
48944905
}
48954906

4896-
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, a->ne[0], a->ne[1], a->ne[2] + b->ne[2], a->ne[3]);
4907+
struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne);
4908+
4909+
ggml_set_op_params_i32(result, 0, dim);
48974910

48984911
result->op = GGML_OP_CONCAT;
48994912
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5013,6 +5026,7 @@ struct ggml_tensor * ggml_leaky_relu(
50135026
}
50145027

50155028
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
5029+
50165030
ggml_set_op_params(result, &negative_slope, sizeof(negative_slope));
50175031

50185032
result->op = GGML_OP_LEAKY_RELU;
@@ -10967,34 +10981,37 @@ static void ggml_compute_forward_concat_f32(
1096710981
GGML_ASSERT(nb00 == sizeof(float));
1096810982
GGML_ASSERT(nb10 == sizeof(float));
1096910983

10984+
const int32_t dim = ggml_get_op_params_i32(dst, 0);
10985+
10986+
GGML_ASSERT(dim >= 0 && dim < 4);
10987+
10988+
int64_t o[4] = {0, 0, 0, 0};
10989+
o[dim] = src0->ne[dim];
10990+
10991+
const float * x;
10992+
10993+
// TODO: smarter multi-theading
1097010994
for (int i3 = 0; i3 < ne3; i3++) {
1097110995
for (int i2 = ith; i2 < ne2; i2 += nth) {
10972-
if (i2 < ne02) { // src0
10973-
for (int i1 = 0; i1 < ne1; i1++) {
10974-
for (int i0 = 0; i0 < ne0; i0++) {
10975-
const float * x = (float *)((char *) src0->data + i0 * nb00 + i1 * nb01 + i2 * nb02 + i3 * nb03);
10976-
10977-
float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
10978-
*y = *x;
10979-
}
10980-
}
10981-
} // src1
10982-
else {
10983-
for (int i1 = 0; i1 < ne1; i1++) {
10984-
for (int i0 = 0; i0 < ne0; i0++) {
10985-
const float * x = (float *)((char *) src1->data + i0 * nb10 + i1 * nb11 + (i2 - ne02) * nb12 + i3 * nb13);
10986-
10987-
float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
10988-
*y = *x;
10996+
for (int i1 = 0; i1 < ne1; i1++) {
10997+
for (int i0 = 0; i0 < ne0; i0++) {
10998+
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
10999+
x = (const float *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
11000+
} else {
11001+
x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
1098911002
}
11003+
11004+
float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
11005+
11006+
*y = *x;
1099011007
}
1099111008
}
1099211009
}
1099311010
}
1099411011
}
1099511012

1099611013
static void ggml_compute_forward_concat(
10997-
const struct ggml_compute_params* params,
11014+
const struct ggml_compute_params * params,
1099811015
struct ggml_tensor* dst) {
1099911016

1100011017
const struct ggml_tensor * src0 = dst->src[0];

ggml.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,12 +1007,13 @@ extern "C" {
10071007
struct ggml_tensor * a,
10081008
struct ggml_tensor * b);
10091009

1010-
// concat a and b on dim 2
1010+
// concat a and b along dim
10111011
// used in stable-diffusion
10121012
GGML_API struct ggml_tensor * ggml_concat(
10131013
struct ggml_context * ctx,
10141014
struct ggml_tensor * a,
1015-
struct ggml_tensor * b);
1015+
struct ggml_tensor * b,
1016+
int dim);
10161017

10171018
GGML_API struct ggml_tensor * ggml_abs(
10181019
struct ggml_context * ctx,

tests/test-backend-ops.cpp

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1259,22 +1259,26 @@ struct test_im2col : public test_case {
12591259
// GGML_OP_CONCAT
12601260
struct test_concat : public test_case {
12611261
const ggml_type type;
1262-
const std::array<int64_t, 4> ne;
1263-
const int64_t b_ne2;
1262+
const std::array<int64_t, 4> ne_a;
1263+
const int64_t ne_b_d;
1264+
const int dim;
12641265

12651266
std::string vars() override {
1266-
return VARS_TO_STR3(type, ne, b_ne2);
1267+
return VARS_TO_STR4(type, ne_a, ne_b_d, dim);
12671268
}
12681269

12691270
test_concat(ggml_type type = GGML_TYPE_F32,
1270-
std::array<int64_t, 4> ne = {10, 10, 10, 10},
1271-
int64_t b_ne2 = 10)
1272-
: type(type), ne(ne), b_ne2(b_ne2) {}
1271+
std::array<int64_t, 4> ne_a = {10, 10, 10, 10},
1272+
int64_t ne_b_d = 10,
1273+
int dim = 2)
1274+
: type(type), ne_a(ne_a), ne_b_d(ne_b_d), dim(dim) {}
12731275

12741276
ggml_tensor * build_graph(ggml_context * ctx) override {
1275-
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
1276-
ggml_tensor * b = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], b_ne2, ne[3]);
1277-
ggml_tensor * out = ggml_concat(ctx, a, b);
1277+
auto ne_b = ne_a;
1278+
ne_b[dim] = ne_b_d;
1279+
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
1280+
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data());
1281+
ggml_tensor * out = ggml_concat(ctx, a, b, dim);
12781282
return out;
12791283
}
12801284
};
@@ -2211,8 +2215,10 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
22112215
}
22122216
}
22132217

2214-
test_cases.emplace_back(new test_concat(GGML_TYPE_F32));
2215-
test_cases.emplace_back(new test_concat(GGML_TYPE_I32));
2218+
for (int dim : { 0, 1, 2, 3, }) {
2219+
test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {11, 12, 13, 14}, 7, dim));
2220+
test_cases.emplace_back(new test_concat(GGML_TYPE_I32, {11, 12, 13, 14}, 7, dim));
2221+
}
22162222

22172223
for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) {
22182224
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));

0 commit comments

Comments
 (0)