Skip to content

Commit c3ac702

Browse files
committed
ggml : add ggml_cont() + optimize ggml_cpy() for contiguous dst
1 parent 9d634ef commit c3ac702

File tree

2 files changed

+252
-8
lines changed

2 files changed

+252
-8
lines changed

ggml.c

Lines changed: 246 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2609,6 +2609,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
26092609

26102610
"SCALE",
26112611
"CPY",
2612+
"CONT",
26122613
"RESHAPE",
26132614
"VIEW",
26142615
"PERMUTE",
@@ -2624,7 +2625,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
26242625
"FLASH_FF",
26252626
};
26262627

2627-
static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
2628+
static_assert(GGML_OP_COUNT == 36, "GGML_OP_COUNT != 36");
26282629

26292630
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
26302631
"none",
@@ -2653,6 +2654,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
26532654

26542655
"x*v",
26552656
"x-\\>y",
2657+
"cont(x)",
26562658
"reshape(x)",
26572659
"view(x)",
26582660
"permute(x)",
@@ -2668,7 +2670,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
26682670
"flash_ff(x)",
26692671
};
26702672

2671-
static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
2673+
static_assert(GGML_OP_COUNT == 36, "GGML_OP_COUNT != 36");
26722674

26732675
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
26742676
static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -4301,6 +4303,41 @@ struct ggml_tensor * ggml_cpy_inplace(
43014303
return ggml_cpy_impl(ctx, a, b, true);
43024304
}
43034305

4306+
// ggml_cont
4307+
4308+
struct ggml_tensor * ggml_cont_impl(
4309+
struct ggml_context * ctx,
4310+
struct ggml_tensor * a,
4311+
bool inplace) {
4312+
bool is_node = false;
4313+
4314+
if (!inplace && a->grad) {
4315+
GGML_ASSERT(false); // TODO: implement backward
4316+
is_node = true;
4317+
}
4318+
4319+
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
4320+
4321+
result->op = GGML_OP_CONT;
4322+
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
4323+
result->src0 = a;
4324+
result->src1 = NULL;
4325+
4326+
return result;
4327+
}
4328+
4329+
struct ggml_tensor * ggml_cont(
4330+
struct ggml_context * ctx,
4331+
struct ggml_tensor * a) {
4332+
return ggml_cont_impl(ctx, a, false);
4333+
}
4334+
4335+
struct ggml_tensor * ggml_cont_inplace(
4336+
struct ggml_context * ctx,
4337+
struct ggml_tensor * a) {
4338+
return ggml_cont_impl(ctx, a, true);
4339+
}
4340+
43044341
// ggml_reshape
43054342

43064343
struct ggml_tensor * ggml_reshape(
@@ -4843,6 +4880,85 @@ static void ggml_compute_forward_dup_f16(
48434880

48444881
// TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
48454882

4883+
if (ggml_is_contiguous(dst)) {
4884+
if (src0->nb[0] == sizeof(ggml_fp16_t)) {
4885+
if (dst->type == GGML_TYPE_F16) {
4886+
size_t id = 0;
4887+
const size_t rs = ne00*nb00;
4888+
4889+
for (int i03 = 0; i03 < ne03; i03++) {
4890+
for (int i02 = 0; i02 < ne02; i02++) {
4891+
for (int i01 = 0; i01 < ne01; i01++) {
4892+
const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
4893+
char * dst_ptr = (char *) dst->data + id*rs;
4894+
4895+
memcpy(dst_ptr, src0_ptr, rs);
4896+
4897+
id++;
4898+
}
4899+
}
4900+
}
4901+
} else if (dst->type == GGML_TYPE_F32) {
4902+
size_t id = 0;
4903+
float * dst_ptr = (float *) dst->data;
4904+
4905+
for (int i03 = 0; i03 < ne03; i03++) {
4906+
for (int i02 = 0; i02 < ne02; i02++) {
4907+
for (int i01 = 0; i01 < ne01; i01++) {
4908+
for (int i00 = 0; i00 < ne00; i00++) {
4909+
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
4910+
4911+
dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
4912+
id++;
4913+
}
4914+
}
4915+
}
4916+
}
4917+
} else {
4918+
GGML_ASSERT(false); // TODO: implement
4919+
}
4920+
} else {
4921+
//printf("%s: this is not optimal - fix me\n", __func__);
4922+
4923+
if (dst->type == GGML_TYPE_F32) {
4924+
size_t id = 0;
4925+
float * dst_ptr = (float *) dst->data;
4926+
4927+
for (int i03 = 0; i03 < ne03; i03++) {
4928+
for (int i02 = 0; i02 < ne02; i02++) {
4929+
for (int i01 = 0; i01 < ne01; i01++) {
4930+
for (int i00 = 0; i00 < ne00; i00++) {
4931+
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
4932+
4933+
dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
4934+
id++;
4935+
}
4936+
}
4937+
}
4938+
}
4939+
} else if (dst->type == GGML_TYPE_F16) {
4940+
size_t id = 0;
4941+
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
4942+
4943+
for (int i03 = 0; i03 < ne03; i03++) {
4944+
for (int i02 = 0; i02 < ne02; i02++) {
4945+
for (int i01 = 0; i01 < ne01; i01++) {
4946+
for (int i00 = 0; i00 < ne00; i00++) {
4947+
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
4948+
4949+
dst_ptr[id] = *src0_ptr;
4950+
id++;
4951+
}
4952+
}
4953+
}
4954+
}
4955+
} else {
4956+
GGML_ASSERT(false); // TODO: implement
4957+
}
4958+
}
4959+
return;
4960+
}
4961+
48464962
// dst counters
48474963
int64_t i10 = 0;
48484964
int64_t i11 = 0;
@@ -4937,6 +5053,105 @@ static void ggml_compute_forward_dup_f32(
49375053
return;
49385054
}
49395055

5056+
if (src0->type == dst->type &&
5057+
src0->ne[0] == dst->ne[0] &&
5058+
src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
5059+
// copy by rows
5060+
const size_t rs = ne00*nb00;
5061+
for (int64_t i03 = 0; i03 < ne03; i03++) {
5062+
for (int64_t i02 = 0; i02 < ne02; i02++) {
5063+
for (int64_t i01 = 0; i01 < ne01; i01++) {
5064+
memcpy(
5065+
((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
5066+
((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
5067+
rs);
5068+
}
5069+
}
5070+
}
5071+
return;
5072+
}
5073+
5074+
if (ggml_is_contiguous(dst)) {
5075+
// TODO: simplify
5076+
if (src0->nb[0] == sizeof(float)) {
5077+
if (dst->type == GGML_TYPE_F32) {
5078+
size_t id = 0;
5079+
const size_t rs = ne00*nb00;
5080+
5081+
for (int i03 = 0; i03 < ne03; i03++) {
5082+
for (int i02 = 0; i02 < ne02; i02++) {
5083+
for (int i01 = 0; i01 < ne01; i01++) {
5084+
const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
5085+
char * dst_ptr = (char *) dst->data + id*rs;
5086+
5087+
memcpy(dst_ptr, src0_ptr, rs);
5088+
5089+
id++;
5090+
}
5091+
}
5092+
}
5093+
} else if (dst->type == GGML_TYPE_F16) {
5094+
size_t id = 0;
5095+
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
5096+
5097+
for (int i03 = 0; i03 < ne03; i03++) {
5098+
for (int i02 = 0; i02 < ne02; i02++) {
5099+
for (int i01 = 0; i01 < ne01; i01++) {
5100+
for (int i00 = 0; i00 < ne00; i00++) {
5101+
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5102+
5103+
dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
5104+
id++;
5105+
}
5106+
}
5107+
}
5108+
}
5109+
} else {
5110+
GGML_ASSERT(false); // TODO: implement
5111+
}
5112+
} else {
5113+
//printf("%s: this is not optimal - fix me\n", __func__);
5114+
5115+
if (dst->type == GGML_TYPE_F32) {
5116+
size_t id = 0;
5117+
float * dst_ptr = (float *) dst->data;
5118+
5119+
for (int i03 = 0; i03 < ne03; i03++) {
5120+
for (int i02 = 0; i02 < ne02; i02++) {
5121+
for (int i01 = 0; i01 < ne01; i01++) {
5122+
for (int i00 = 0; i00 < ne00; i00++) {
5123+
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5124+
5125+
dst_ptr[id] = *src0_ptr;
5126+
id++;
5127+
}
5128+
}
5129+
}
5130+
}
5131+
} else if (dst->type == GGML_TYPE_F16) {
5132+
size_t id = 0;
5133+
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
5134+
5135+
for (int i03 = 0; i03 < ne03; i03++) {
5136+
for (int i02 = 0; i02 < ne02; i02++) {
5137+
for (int i01 = 0; i01 < ne01; i01++) {
5138+
for (int i00 = 0; i00 < ne00; i00++) {
5139+
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5140+
5141+
dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
5142+
id++;
5143+
}
5144+
}
5145+
}
5146+
}
5147+
} else {
5148+
GGML_ASSERT(false); // TODO: implement
5149+
}
5150+
}
5151+
5152+
return;
5153+
}
5154+
49405155
// dst counters
49415156
int64_t i10 = 0;
49425157
int64_t i11 = 0;
@@ -5057,14 +5272,18 @@ static void ggml_compute_forward_add_f32(
50575272
GGML_ASSERT(nb00 == sizeof(float));
50585273

50595274
if (nb10 == sizeof(float)) {
5060-
const int j0 = (n/nth)*ith;
5061-
const int j1 = ith == nth - 1 ? n : (n/nth)*(ith + 1);
5062-
5063-
for (int j = j0; j < j1; j++) {
5275+
for (int j = ith; j < n; j += nth) {
5276+
#ifdef GGML_USE_ACCELERATE
5277+
vDSP_vadd(
5278+
(float *) ((char *) src0->data + j*nb01), 1,
5279+
(float *) ((char *) src1->data + j*nb11), 1,
5280+
(float *) ((char *) dst->data + j*nb1), 1, nc);
5281+
#else
50645282
ggml_vec_add_f32(nc,
50655283
(float *) ((char *) dst->data + j*nb1),
50665284
(float *) ((char *) src0->data + j*nb01),
50675285
(float *) ((char *) src1->data + j*nb11));
5286+
#endif
50685287
}
50695288
} else {
50705289
// src1 is not contiguous
@@ -6812,6 +7031,15 @@ static void ggml_compute_forward_cpy(
68127031
ggml_compute_forward_dup(params, src0, dst);
68137032
}
68147033

7034+
// ggml_compute_forward_cont

// Forward pass for GGML_OP_CONT. Making a tensor contiguous is the same
// computation as a copy, so delegate to the dup kernel (which carries the
// contiguous-dst fast paths).
static void ggml_compute_forward_cont(
        const struct ggml_compute_params * params,
        const struct ggml_tensor * src0,
        struct ggml_tensor * dst) {
    ggml_compute_forward_dup(params, src0, dst);
}
7042+
68157043
// ggml_compute_forward_reshape
68167044

68177045
static void ggml_compute_forward_reshape(
@@ -8642,6 +8870,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
86428870
{
86438871
ggml_compute_forward_cpy(params, tensor->src0, tensor);
86448872
} break;
8873+
case GGML_OP_CONT:
8874+
{
8875+
ggml_compute_forward_cont(params, tensor->src0, tensor);
8876+
} break;
86458877
case GGML_OP_RESHAPE:
86468878
{
86478879
ggml_compute_forward_reshape(params, tensor->src0, tensor);
@@ -8886,8 +9118,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
88869118
src1->grad =
88879119
ggml_add_impl(ctx,
88889120
src1->grad,
8889-
// TODO: fix transpose, the node will break the graph connections
8890-
ggml_mul_mat(ctx, ggml_transpose(ctx, src0), tensor->grad),
9121+
ggml_mul_mat(ctx,
9122+
ggml_cont(ctx, ggml_transpose(ctx, src0)),
9123+
tensor->grad),
88919124
inplace);
88929125
}
88939126
} break;
@@ -8899,6 +9132,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
88999132
{
89009133
GGML_ASSERT(false); // TODO: not implemented
89019134
} break;
9135+
case GGML_OP_CONT:
9136+
{
9137+
GGML_ASSERT(false); // TODO: not implemented
9138+
} break;
89029139
case GGML_OP_RESHAPE:
89039140
{
89049141
GGML_ASSERT(false); // TODO: not implemented
@@ -9353,6 +9590,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
93539590
node->n_tasks = n_threads;
93549591
} break;
93559592
case GGML_OP_CPY:
9593+
case GGML_OP_CONT:
93569594
case GGML_OP_RESHAPE:
93579595
case GGML_OP_VIEW:
93589596
case GGML_OP_PERMUTE:

ggml.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ enum ggml_op {
236236

237237
GGML_OP_SCALE,
238238
GGML_OP_CPY,
239+
GGML_OP_CONT,
239240
GGML_OP_RESHAPE,
240241
GGML_OP_VIEW,
241242
GGML_OP_PERMUTE,
@@ -525,6 +526,11 @@ struct ggml_tensor * ggml_cpy(
525526
struct ggml_tensor * a,
526527
struct ggml_tensor * b);
527528

529+
// make contiguous
530+
struct ggml_tensor * ggml_cont(
531+
struct ggml_context * ctx,
532+
struct ggml_tensor * a);
533+
528534
// return view(a), b specifies the new shape
529535
// TODO: when we start computing gradient, make a copy instead of view
530536
struct ggml_tensor * ggml_reshape(

0 commit comments

Comments
 (0)