@@ -186,6 +186,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_sqr_f32;
     vk_pipeline pipeline_clamp_f32;
     vk_pipeline pipeline_pad_f32;
+    vk_pipeline pipeline_repeat_f32;
     vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
     vk_pipeline pipeline_norm_f32;
     vk_pipeline pipeline_group_norm_f32;
@@ -1659,6 +1660,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
@@ -3847,76 +3850,6 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
     }
 }
 
-static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    VK_LOG_DEBUG("ggml_vk_op_repeat(" << src0 << ", " << src1 << ", " << dst << ")");
-    const uint64_t ne0 = dst->ne[0];
-    const uint64_t ne1 = dst->ne[1];
-    const uint64_t ne2 = dst->ne[2];
-    const uint64_t ne3 = dst->ne[3];
-
-    const uint64_t ne00 = src0->ne[0];
-    const uint64_t ne01 = src0->ne[1];
-    const uint64_t ne02 = src0->ne[2];
-    const uint64_t ne03 = src0->ne[3];
-
-    const uint64_t nb0 = dst->nb[0];
-    const uint64_t nb1 = dst->nb[1];
-    const uint64_t nb2 = dst->nb[2];
-    const uint64_t nb3 = dst->nb[3];
-
-    const uint64_t nb00 = src0->nb[0];
-    const uint64_t nb01 = src0->nb[1];
-    const uint64_t nb02 = src0->nb[2];
-    const uint64_t nb03 = src0->nb[3];
-
-    // guaranteed to be an integer due to the check in ggml_can_repeat
-    const uint64_t nr0 = ne0/ne00;
-    const uint64_t nr1 = ne1/ne01;
-    const uint64_t nr2 = ne2/ne02;
-    const uint64_t nr3 = ne3/ne03;
-
-    // TODO: support for transposed / permuted tensors
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb00 == sizeof(float));
-
-    ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
-    ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
-
-    const vk_buffer src_buf = extra_src0->buffer_gpu.lock();
-    const uint64_t src_offset = extra_src0->offset + src0->view_offs;
-    vk_buffer dst_buf = extra->buffer_gpu.lock();
-    const uint64_t dst_offset = extra->offset + dst->view_offs;
-
-    std::vector<vk::BufferCopy> copies;
-
-    for (uint64_t i3 = 0; i3 < nr3; i3++) {
-        for (uint64_t k3 = 0; k3 < ne03; k3++) {
-            for (uint64_t i2 = 0; i2 < nr2; i2++) {
-                for (uint64_t k2 = 0; k2 < ne02; k2++) {
-                    for (uint64_t i1 = 0; i1 < nr1; i1++) {
-                        for (uint64_t k1 = 0; k1 < ne01; k1++) {
-                            for (uint64_t i0 = 0; i0 < nr0; i0++) {
-                                copies.push_back({
-                                    src_offset + (          k3)*nb03 + (          k2)*nb02 + (          k1)*nb01,
-                                    dst_offset + (i3*ne03 + k3)*nb3  + (i2*ne02 + k2)*nb2  + (i1*ne01 + k1)*nb1  + (i0*ne00)*nb0,
-                                    ne00*nb0,
-                                });
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    }
-
-    ggml_vk_sync_buffers(subctx);
-    subctx->s->buffer.copyBuffer(src_buf->buffer, dst_buf->buffer, copies);
-
-    GGML_UNUSED(ctx);
-    GGML_UNUSED(src1);
-}
-
-
 static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
     switch (op) {
     case GGML_OP_GET_ROWS:
@@ -3982,6 +3915,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_pad_f32;
         }
         return nullptr;
+    case GGML_OP_REPEAT:
+        if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) {
+            return ctx->device->pipeline_repeat_f32;
+        }
+        return nullptr;
     case GGML_OP_CPY:
     case GGML_OP_CONT:
     case GGML_OP_DUP:
@@ -4104,15 +4042,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     GGML_UNUSED(src2);
 }
 
-static ggml_vk_func_t ggml_vk_op_get_func(ggml_op op) {
-    switch (op) {
-    case GGML_OP_REPEAT:
-        return ggml_vk_op_repeat;
-    default:
-        return nullptr;
-    }
-}
-
 static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     switch (op) {
     case GGML_OP_CPY:
@@ -4126,6 +4055,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_SQR:
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
+    case GGML_OP_REPEAT:
         return true;
     default:
         return false;
@@ -4173,21 +4103,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     const uint64_t ned = ned0 * ned1;
 
     vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op);
-    ggml_vk_func_t op_func;
 
     if (pipeline == nullptr) {
-        op_func = ggml_vk_op_get_func(op);
-        if (op_func == nullptr) {
-            std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(op) << " for " << ggml_type_name(src0->type);
-            if (src1 != nullptr) {
-                std::cerr << " and " << ggml_type_name(src1->type);
-            }
-            std::cerr << " to " << ggml_type_name(dst->type) << std::endl;
-            GGML_ABORT("fatal error");
+        std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(op) << " for " << ggml_type_name(src0->type);
+        if (src1 != nullptr) {
+            std::cerr << " and " << ggml_type_name(src1->type);
         }
-
-        op_func(ctx, subctx, src0, src1, dst);
-        return;
+        std::cerr << " to " << ggml_type_name(dst->type) << std::endl;
+        GGML_ABORT("fatal error");
     }
 
     const bool op_supports_incontiguous = ggml_vk_op_supports_incontiguous(op);
@@ -4337,6 +4260,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_SQR:
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
+    case GGML_OP_REPEAT:
     case GGML_OP_CPY:
     case GGML_OP_CONCAT:
     case GGML_OP_UPSCALE:
@@ -4452,10 +4376,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     }
 }
 
-static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, {});
-}
-
 static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t src1_type_size = ggml_type_size(src1->type);
@@ -4603,6 +4523,19 @@ static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const
     });
 }
 
+static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, {
+        (uint32_t)ggml_nelements(dst),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
+        0,
+        0.0f, 0.0f,
+    });
+}
+
 static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
     const uint32_t src0_type_size = ggml_type_size(src0->type);
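
Note on the hunk above: the initializer list passed by the new ggml_vk_repeat fills the unary push-constant block in a fixed order: the destination element count, src0 dims and strides (strides divided by the type size, so expressed in elements), dst dims and strides, a destination offset, and two float parameters that REPEAT does not use. A rough C++ sketch of that layout follows; the struct and field names here are assumptions made for illustration, not the declaration from ggml-vulkan.cpp.

#include <cstdint>

// Illustrative sketch only: field names are assumed, not copied from the source.
// The field order mirrors the initializer list built by ggml_vk_repeat.
struct vk_op_unary_push_constants_sketch {
    uint32_t ne;                          // (uint32_t) ggml_nelements(dst)
    uint32_t ne00, ne01, ne02, ne03;      // src0->ne[0..3]
    uint32_t nb00, nb01, nb02, nb03;      // src0->nb[0..3] / src0_type_size (strides in elements)
    uint32_t ne10, ne11, ne12, ne13;      // dst->ne[0..3]
    uint32_t nb10, nb11, nb12, nb13;      // dst->nb[0..3] / dst_type_size
    uint32_t d_offset;                    // 0 for GGML_OP_REPEAT
    float    param1, param2;              // unused by repeat, passed as 0.0f
};
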
@@ -6637,10 +6570,7 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
                 return false;
             } break;
         case GGML_OP_REPEAT:
-            {
-                ggml_type src0_type = op->src[0]->type;
-                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
-            } break;
+            return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);
         case GGML_OP_ROPE:
            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_NONE:
@@ -7104,6 +7034,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         tensor_clone = ggml_clamp(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
     } else if (tensor->op == GGML_OP_PAD) {
         tensor_clone = ggml_pad(ggml_ctx, src0_clone, tensor->ne[0] - src0_clone->ne[0], tensor->ne[1] - src0_clone->ne[1], tensor->ne[2] - src0_clone->ne[2], tensor->ne[3] - src0_clone->ne[3]);
+    } else if (tensor->op == GGML_OP_REPEAT) {
+        tensor_clone = ggml_repeat(ggml_ctx, src0_clone, src1_clone);
     } else if (tensor->op == GGML_OP_ADD) {
         tensor_clone = ggml_add(ggml_ctx, src0_clone, src1_clone);
     } else if (tensor->op == GGML_OP_NORM) {
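
The removed copy-based ggml_vk_op_repeat and the new repeat_f32 pipeline implement the same semantics: each destination element reads the source element at the same coordinates taken modulo the source extents (ggml_can_repeat guarantees the destination dims are whole multiples of the source dims). A minimal CPU reference of that indexing, assuming contiguous f32 tensors, is sketched below; it illustrates the operation, it is not the shader code from this commit.

#include <cstdint>

// CPU reference for GGML_OP_REPEAT on contiguous f32 data: every destination
// element (i0, i1, i2, i3) copies the source element at the same coordinates
// wrapped modulo the source extents.
static void repeat_f32_reference(const float * src, const uint64_t src_ne[4],
                                 float * dst, const uint64_t dst_ne[4]) {
    for (uint64_t i3 = 0; i3 < dst_ne[3]; i3++) {
        for (uint64_t i2 = 0; i2 < dst_ne[2]; i2++) {
            for (uint64_t i1 = 0; i1 < dst_ne[1]; i1++) {
                for (uint64_t i0 = 0; i0 < dst_ne[0]; i0++) {
                    // linear index into the (contiguous) source, with wrapped coordinates
                    const uint64_t s = (((i3 % src_ne[3]) * src_ne[2] + (i2 % src_ne[2])) * src_ne[1]
                                       + (i1 % src_ne[1])) * src_ne[0] + (i0 % src_ne[0]);
                    // linear index into the (contiguous) destination
                    const uint64_t d = ((i3 * dst_ne[2] + i2) * dst_ne[1] + i1) * dst_ne[0] + i0;
                    dst[d] = src[s];
                }
            }
        }
    }
}

On the GPU this becomes a single dispatch of the repeat_f32 pipeline through the generic ggml_vk_op_f32 path, presumably one shader invocation per destination element, instead of the list of vkCmdCopyBuffer regions built by the removed loop.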