Skip to content

Commit cce3dcf

Browse files
authored
cuda : non-cont concat support (#7610)
* tests : add non-cont concat tests * cuda : non-cont concat support ggml-ci
1 parent 210d991 commit cce3dcf

File tree

2 files changed

+113
-30
lines changed

2 files changed

+113
-30
lines changed

ggml-cuda/concat.cu

Lines changed: 88 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "concat.cuh"
22

3+
// contiguous kernels
34
static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) {
45
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
56
if (nidx >= ne0) {
@@ -92,39 +93,104 @@ static void concat_f32_cuda(const float * x, const float * y, float * dst, int n
9293
concat_f32_dim2<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
9394
}
9495

96+
// non-contiguous kernel (slow)
//
// Concatenates two f32 tensors along `dim`, copying one element at a time and
// addressing every tensor through its byte strides (hence the char pointers).
//
// Launch layout: blockIdx.x / blockIdx.y / blockIdx.z are used directly as the
// dst indices along dims 1 / 2 / 3, and the threads of a block stride over
// dim 0 in steps of blockDim.x. No shared memory is used.
//
// Parameters:
//   src0, src1, dst  base addresses of the two inputs and the output
//   ne00..ne03       extents of src0
//   nb00..nb03       byte strides of src0
//   (ne10..ne13)     extents of src1 - accepted for a uniform signature, unused
//   nb10..nb13       byte strides of src1
//   ne0              extent of dst along dim 0 (ne1..ne3 are unused; the grid
//                    dimensions take their place)
//   nb0..nb3         byte strides of dst
//   dim              concatenation axis, expected to be in [0, 3]
//
// A dst element whose index lies inside src0's extents on all four axes is read
// from src0; every other element is read from src1 at the same index shifted
// back by src0's extent along `dim`.
static __global__ void concat_f32_non_cont(
        const char * src0,
        const char * src1,
              char * dst,
           int64_t   ne00,
           int64_t   ne01,
           int64_t   ne02,
           int64_t   ne03,
          uint64_t   nb00,
          uint64_t   nb01,
          uint64_t   nb02,
          uint64_t   nb03,
           int64_t /*ne10*/,
           int64_t /*ne11*/,
           int64_t /*ne12*/,
           int64_t /*ne13*/,
          uint64_t   nb10,
          uint64_t   nb11,
          uint64_t   nb12,
          uint64_t   nb13,
           int64_t   ne0,
           int64_t /*ne1*/,
           int64_t /*ne2*/,
           int64_t /*ne3*/,
          uint64_t   nb0,
          uint64_t   nb1,
          uint64_t   nb2,
          uint64_t   nb3,
           int32_t   dim) {
    const int64_t i3 = blockIdx.z;
    const int64_t i2 = blockIdx.y;
    const int64_t i1 = blockIdx.x;

    // Position of src1 inside dst: src0's extent along the concat axis and 0 on
    // the other axes. Kept as four scalars rather than an array indexed by
    // `dim`, so no runtime-indexed local array is needed.
    const int64_t o0 = dim == 0 ? ne00 : 0;
    const int64_t o1 = dim == 1 ? ne01 : 0;
    const int64_t o2 = dim == 2 ? ne02 : 0;
    const int64_t o3 = dim == 3 ? ne03 : 0;

    // Everything that depends only on (i1, i2, i3) is the same for the whole
    // block, so evaluate it once instead of once per element.
    const bool row_in_src0 = i1 < ne01 && i2 < ne02 && i3 < ne03;
    char * dst_row = dst + i3*nb3 + i2*nb2 + i1*nb1;

    // i0 is 64-bit like ne0: a 32-bit index would overflow for rows longer
    // than INT_MAX elements.
    for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
        const float * x;

        if (row_in_src0 && i0 < ne00) {
            x = (const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
        } else {
            x = (const float *)(src1 + (i3 - o3)*nb13 + (i2 - o2)*nb12 + (i1 - o1)*nb11 + (i0 - o0)*nb10);
        }

        float * y = (float *)(dst_row + i0*nb0);

        *y = *x;
    }
}
147+
148+
95149
void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
96150
const ggml_tensor * src0 = dst->src[0];
97151
const ggml_tensor * src1 = dst->src[1];
98152

99-
const float * src0_d = (const float *)src0->data;
100-
const float * src1_d = (const float *)src1->data;
101-
102-
float * dst_d = (float *)dst->data;
103153
cudaStream_t stream = ctx.stream();
104154

105155
const int32_t dim = ((int32_t *) dst->op_params)[0];
106156

107-
GGML_ASSERT(ggml_is_contiguous(src0));
108-
GGML_ASSERT(ggml_is_contiguous(src1));
109-
110157
GGML_ASSERT(src0->type == GGML_TYPE_F32);
111158
GGML_ASSERT(src1->type == GGML_TYPE_F32);
112-
GGML_ASSERT(dst->type == GGML_TYPE_F32);
113-
114-
if (dim != 3) {
115-
for (int i3 = 0; i3 < dst->ne[3]; i3++) {
116-
concat_f32_cuda(
117-
src0_d + i3 * (src0->nb[3] / 4),
118-
src1_d + i3 * (src1->nb[3] / 4),
119-
dst_d + i3 * ( dst->nb[3] / 4),
120-
src0->ne[0], src0->ne[1], src0->ne[2],
121-
dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
159+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
160+
161+
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
162+
const float * src0_d = (const float *)src0->data;
163+
const float * src1_d = (const float *)src1->data;
164+
165+
float * dst_d = (float *)dst->data;
166+
167+
if (dim != 3) {
168+
for (int i3 = 0; i3 < dst->ne[3]; i3++) {
169+
concat_f32_cuda(
170+
src0_d + i3 * (src0->nb[3] / 4),
171+
src1_d + i3 * (src1->nb[3] / 4),
172+
dst_d + i3 * ( dst->nb[3] / 4),
173+
src0->ne[0], src0->ne[1], src0->ne[2],
174+
dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
175+
}
176+
} else {
177+
const size_t size0 = ggml_nbytes(src0);
178+
const size_t size1 = ggml_nbytes(src1);
179+
180+
CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
181+
CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
122182
}
123183
} else {
124-
const size_t size0 = ggml_nbytes(src0);
125-
const size_t size1 = ggml_nbytes(src1);
126-
127-
CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
128-
CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
184+
dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
185+
concat_f32_non_cont<<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
186+
(const char *)src0->data,
187+
(const char *)src1->data,
188+
( char *)dst->data,
189+
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
190+
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
191+
src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
192+
src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
193+
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
194+
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], dim);
129195
}
130196
}

tests/test-backend-ops.cpp

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1262,22 +1262,37 @@ struct test_concat : public test_case {
12621262
const std::array<int64_t, 4> ne_a;
12631263
const int64_t ne_b_d;
12641264
const int dim;
1265+
const int v; // view (1 << 0: non-cont a, 1 << 1: non-cont b)
12651266

12661267
std::string vars() override {
1267-
return VARS_TO_STR4(type, ne_a, ne_b_d, dim);
1268+
return VARS_TO_STR5(type, ne_a, ne_b_d, dim, v);
12681269
}
12691270

12701271
test_concat(ggml_type type = GGML_TYPE_F32,
12711272
std::array<int64_t, 4> ne_a = {10, 10, 10, 10},
12721273
int64_t ne_b_d = 10,
1273-
int dim = 2)
1274-
: type(type), ne_a(ne_a), ne_b_d(ne_b_d), dim(dim) {}
1274+
int dim = 2, int v = 0)
1275+
: type(type), ne_a(ne_a), ne_b_d(ne_b_d), dim(dim), v(v) {}
12751276

12761277
ggml_tensor * build_graph(ggml_context * ctx) override {
12771278
auto ne_b = ne_a;
12781279
ne_b[dim] = ne_b_d;
1279-
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
1280-
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data());
1280+
ggml_tensor * a;
1281+
if (v & 1) {
1282+
auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3;
1283+
a = ggml_new_tensor(ctx, type, 4, ne.data());
1284+
a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
1285+
} else {
1286+
a = ggml_new_tensor(ctx, type, 4, ne_a.data());
1287+
}
1288+
ggml_tensor * b;
1289+
if (v & 2) {
1290+
auto ne = ne_b; ne[0] *= 3; ne[1] *= 2; ne[2] *= 4;
1291+
b = ggml_new_tensor(ctx, type, 4, ne.data());
1292+
b = ggml_view_4d(ctx, b, ne_b[0], ne_b[1], ne_b[2], ne_b[3], b->nb[1], b->nb[2], b->nb[3], 0);
1293+
} else {
1294+
b = ggml_new_tensor(ctx, type, 4, ne_b.data());
1295+
}
12811296
ggml_tensor * out = ggml_concat(ctx, a, b, dim);
12821297
return out;
12831298
}
@@ -2215,9 +2230,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
22152230
}
22162231
}
22172232

2218-
for (int dim : { 0, 1, 2, 3, }) {
2219-
test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {11, 12, 13, 14}, 7, dim));
2220-
test_cases.emplace_back(new test_concat(GGML_TYPE_I32, {11, 12, 13, 14}, 7, dim));
2233+
for (int v : { 0, 1, 2, 3 }) {
2234+
for (int dim : { 0, 1, 2, 3, }) {
2235+
test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {11, 12, 13, 14}, 7, dim, v));
2236+
test_cases.emplace_back(new test_concat(GGML_TYPE_I32, {11, 12, 13, 14}, 7, dim, v));
2237+
}
22212238
}
22222239

22232240
for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) {

0 commit comments

Comments
 (0)