@@ -2609,6 +2609,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "MUL_MAT",
     "SCALE",
     "CPY",
+    "CONT",
     "RESHAPE",
     "VIEW",
     "PERMUTE",
@@ -2624,7 +2625,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "FLASH_FF",
 };

-static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
+static_assert(GGML_OP_COUNT == 36, "GGML_OP_COUNT != 36");

 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2653,6 +2654,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "X*Y",
     "x*v",
     "x-\\>y",
+    "cont(x)",
     "reshape(x)",
     "view(x)",
     "permute(x)",
@@ -2668,7 +2670,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "flash_ff(x)",
 };

-static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
+static_assert(GGML_OP_COUNT == 36, "GGML_OP_COUNT != 36");

 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
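Note: for the two tables above to keep compiling, the op count grows by one because the ggml_op enum in ggml.h presumably gains a GGML_OP_CONT member in the same change (that header is not part of this diff). A minimal sketch of what the enum addition likely looks like, inferred from the table order above:

    // sketch (assumption, not shown in this diff): corresponding enum change in ggml.h
    enum ggml_op {
        // ...
        GGML_OP_SCALE,
        GGML_OP_CPY,
        GGML_OP_CONT,   // new op: materialize a tensor into a contiguous copy
        GGML_OP_RESHAPE,
        GGML_OP_VIEW,
        GGML_OP_PERMUTE,
        // ...
        GGML_OP_COUNT,
    };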
@@ -4301,6 +4303,41 @@ struct ggml_tensor * ggml_cpy_inplace(
     return ggml_cpy_impl(ctx, a, b, true);
 }

+// ggml_cont
+
+struct ggml_tensor * ggml_cont_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_CONT;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_cont(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_cont_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_cont_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_cont_impl(ctx, a, true);
+}
+
 // ggml_reshape

 struct ggml_tensor * ggml_reshape(
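In use, ggml_cont turns a strided view (such as the result of ggml_transpose or ggml_permute) into a plainly laid-out copy; its forward pass simply reuses the dup kernel, as the hunks further down show. A minimal usage sketch, assuming a ggml_context * ctx already exists and using an illustrative tensor size:

    // illustrative sketch, not part of the patch
    struct ggml_tensor * a  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 32);
    struct ggml_tensor * at = ggml_transpose(ctx, a);  // view only: swaps ne/nb, no data is moved
    struct ggml_tensor * ac = ggml_cont(ctx, at);      // GGML_OP_CONT node: copies into contiguous layout
    // after the graph is computed, ac holds the same values as at, laid out contiguously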
@@ -4843,6 +4880,85 @@ static void ggml_compute_forward_dup_f16(

     // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy

+    if (ggml_is_contiguous(dst)) {
+        if (src0->nb[0] == sizeof(ggml_fp16_t)) {
+            if (dst->type == GGML_TYPE_F16) {
+                size_t id = 0;
+                const size_t rs = ne00*nb00;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+                            char * dst_ptr = (char *) dst->data + id*rs;
+
+                            memcpy(dst_ptr, src0_ptr, rs);
+
+                            id++;
+                        }
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                float * dst_ptr = (float *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
+                                id++;
+                            }
+                        }
+                    }
+                }
+            } else {
+                GGML_ASSERT(false); // TODO: implement
+            }
+        } else {
+            //printf("%s: this is not optimal - fix me\n", __func__);
+
+            if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                float * dst_ptr = (float *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
+                                id++;
+                            }
+                        }
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_F16) {
+                size_t id = 0;
+                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = *src0_ptr;
+                                id++;
+                            }
+                        }
+                    }
+                }
+            } else {
+                GGML_ASSERT(false); // TODO: implement
+            }
+        }
+        return;
+    }
+
     // dst counters
     int64_t i10 = 0;
     int64_t i11 = 0;
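Both new branches in the FP16 path walk the source in logical index order using ggml's byte strides, while the contiguous destination is filled sequentially through the running id counter. A small illustrative helper (not in the patch) that spells out the addressing the loops above perform:

    // illustrative (not in the patch): byte address of element (i0,i1,i2,i3) given byte strides nb[4]
    static inline void * tensor_elem(void * data, const size_t nb[4],
                                     int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
        return (char *) data + i0*nb[0] + i1*nb[1] + i2*nb[2] + i3*nb[3];
    }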
@@ -4937,6 +5053,105 @@ static void ggml_compute_forward_dup_f32(
         return;
     }

+    if (src0->type == dst->type &&
+        src0->ne[0] == dst->ne[0] &&
+        src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
+        // copy by rows
+        const size_t rs = ne00*nb00;
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    memcpy(
+                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
+                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
+                        rs);
+                }
+            }
+        }
+        return;
+    }
+
+    if (ggml_is_contiguous(dst)) {
+        // TODO: simplify
+        if (src0->nb[0] == sizeof(float)) {
+            if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                const size_t rs = ne00*nb00;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+                            char * dst_ptr = (char *) dst->data + id*rs;
+
+                            memcpy(dst_ptr, src0_ptr, rs);
+
+                            id++;
+                        }
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_F16) {
+                size_t id = 0;
+                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
+                                id++;
+                            }
+                        }
+                    }
+                }
+            } else {
+                GGML_ASSERT(false); // TODO: implement
+            }
+        } else {
+            //printf("%s: this is not optimal - fix me\n", __func__);
+
+            if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                float * dst_ptr = (float *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = *src0_ptr;
+                                id++;
+                            }
+                        }
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_F16) {
+                size_t id = 0;
+                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
+                                id++;
+                            }
+                        }
+                    }
+                }
+            } else {
+                GGML_ASSERT(false); // TODO: implement
+            }
+        }

+        return;
+    }
+
     // dst counters
     int64_t i10 = 0;
     int64_t i11 = 0;
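The first new branch of the FP32 path copies whole rows with memcpy whenever both tensors have densely packed rows of the same type, even if the outer dimensions are strided. A standalone toy program (not ggml code) showing the same row-by-row copy from a padded source into a contiguous destination:

    // standalone illustration of the "copy by rows" idea
    #include <stdio.h>
    #include <string.h>

    int main(void) {
        float src[3][8];                       // 3 rows, only the first 4 columns are "live" (ne00 = 4)
        float dst[3][4];                       // contiguous destination
        for (int r = 0; r < 3; r++)
            for (int c = 0; c < 8; c++)
                src[r][c] = 10.0f*r + c;

        const size_t rs   = 4*sizeof(float);  // bytes per logical row: ne00*nb00
        const size_t nb01 = 8*sizeof(float);  // source row stride in bytes
        for (int r = 0; r < 3; r++)
            memcpy((char *) dst + (size_t) r*rs, (const char *) src + (size_t) r*nb01, rs);

        printf("%g %g %g %g\n", dst[2][0], dst[2][1], dst[2][2], dst[2][3]); // prints 20 21 22 23
        return 0;
    }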
@@ -5057,14 +5272,18 @@ static void ggml_compute_forward_add_f32(
     GGML_ASSERT(nb00 == sizeof(float));

     if (nb10 == sizeof(float)) {
-        const int j0 = (n/nth)*ith;
-        const int j1 = ith == nth - 1 ? n : (n/nth)*(ith + 1);
-
-        for (int j = j0; j < j1; j++) {
+        for (int j = ith; j < n; j += nth) {
+#ifdef GGML_USE_ACCELERATE
+            vDSP_vadd(
+                    (float *) ((char *) src0->data + j*nb01), 1,
+                    (float *) ((char *) src1->data + j*nb11), 1,
+                    (float *) ((char *) dst->data  + j*nb1),  1, nc);
+#else
             ggml_vec_add_f32(nc,
                     (float *) ((char *) dst->data  + j*nb1),
                     (float *) ((char *) src0->data + j*nb01),
                     (float *) ((char *) src1->data + j*nb11));
+#endif
         }
     } else {
         // src1 is not contiguous
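Two things change here: the per-thread work split moves from contiguous row blocks to a round-robin stride (thread ith handles rows ith, ith+nth, ith+2*nth, ...), and on builds where GGML_USE_ACCELERATE is defined (Apple's Accelerate framework) the inner addition is handed to vDSP_vadd, which adds two float vectors element-wise. A standalone toy program (not ggml code) contrasting the old and new partitioning for n = 10 rows and nth = 3 threads:

    // standalone illustration of the two row-partitioning schemes
    #include <stdio.h>

    int main(void) {
        const int n = 10, nth = 3;
        for (int ith = 0; ith < nth; ith++) {
            // old scheme: one contiguous block per thread, last thread takes the remainder
            const int j0 = (n/nth)*ith;
            const int j1 = ith == nth - 1 ? n : (n/nth)*(ith + 1);
            printf("thread %d  old block [%d, %d)  new round-robin:", ith, j0, j1);
            // new scheme: strided assignment, as in the loop above
            for (int j = ith; j < n; j += nth) printf(" %d", j);
            printf("\n");
        }
        return 0;
    }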
@@ -6812,6 +7031,15 @@ static void ggml_compute_forward_cpy(
     ggml_compute_forward_dup(params, src0, dst);
 }

+// ggml_compute_forward_cont
+
+static void ggml_compute_forward_cont(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    ggml_compute_forward_dup(params, src0, dst);
+}
+
 // ggml_compute_forward_reshape

 static void ggml_compute_forward_reshape(
@@ -8642,6 +8870,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_cpy(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_CONT:
+            {
+                ggml_compute_forward_cont(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_RESHAPE:
             {
                 ggml_compute_forward_reshape(params, tensor->src0, tensor);
@@ -8886,8 +9118,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src1->grad =
                         ggml_add_impl(ctx,
                                 src1->grad,
-                                // TODO: fix transpose, the node will break the graph connections
-                                ggml_mul_mat(ctx, ggml_transpose(ctx, src0), tensor->grad),
+                                ggml_mul_mat(ctx,
+                                    ggml_cont(ctx, ggml_transpose(ctx, src0)),
+                                    tensor->grad),
                                 inplace);
                 }
             } break;
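This is the change ggml_cont exists for: in the backward pass of ggml_mul_mat, the gradient flowing to the second operand involves the transpose of the first (for C = A B in the usual convention, dL/dB = A^T dL/dC). ggml_transpose alone only produces a strided view and, per the removed TODO, such a node broke the graph connections here; wrapping it in ggml_cont materializes the transposed data as a real contiguous tensor that the mul_mat node can consume.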
@@ -8899,6 +9132,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_CONT:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_RESHAPE:
             {
                 GGML_ASSERT(false); // TODO: not implemented
@@ -9353,6 +9590,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                        node->n_tasks = n_threads;
                    } break;
                case GGML_OP_CPY:
+               case GGML_OP_CONT:
                case GGML_OP_RESHAPE:
                case GGML_OP_VIEW:
                case GGML_OP_PERMUTE: