@@ -3180,11 +3180,38 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 
 void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
 
-    ggml_tensor* src0 = dst->src[0]; // q, fp32
-    ggml_tensor* src1 = dst->src[1]; // k, fp16
-    ggml_tensor* src2 = dst->src[2]; // v, fp16
+    ggml_tensor* src0 = dst->src[0]; // q, fp32 | B, N, S, D (uncont) -> B, S, N, D (cont)
+    ggml_tensor* src1 = dst->src[1]; // k, fp16 | B, N, S, D (uncont) -> B, S, N, D (cont)
+    ggml_tensor* src2 = dst->src[2]; // v, fp16 | B, N, S, D (uncont) -> B, S, N, D (cont)
     ggml_tensor* src3 = dst->src[3]; // mask, fp16
 
+    // B, N, S, D (uncont) -> B, S, N, D (cont)
+    int64_t src0_bsnd_ne[GGML_MAX_DIMS];
+    memcpy(src0_bsnd_ne, src0->ne, GGML_MAX_DIMS * sizeof(int64_t));
+    size_t src0_bsnd_nb[GGML_MAX_DIMS];
+    memcpy(src0_bsnd_nb, src0->nb, GGML_MAX_DIMS * sizeof(size_t));
+    int64_t src1_bsnd_ne[GGML_MAX_DIMS];
+    memcpy(src1_bsnd_ne, src1->ne, GGML_MAX_DIMS * sizeof(int64_t));
+    size_t src1_bsnd_nb[GGML_MAX_DIMS];
+    memcpy(src1_bsnd_nb, src1->nb, GGML_MAX_DIMS * sizeof(size_t));
+    int64_t src2_bsnd_ne[GGML_MAX_DIMS];
+    memcpy(src2_bsnd_ne, src2->ne, GGML_MAX_DIMS * sizeof(int64_t));
+    size_t src2_bsnd_nb[GGML_MAX_DIMS];
+    memcpy(src2_bsnd_nb, src2->nb, GGML_MAX_DIMS * sizeof(size_t));
+
+    auto transpose12 = [](int64_t * ne, size_t * nb) {
+        int64_t ne_tmp = ne[1];
+        size_t  nb_tmp = nb[1];
+        ne[1] = ne[2];
+        nb[1] = nb[2];
+        ne[2] = ne_tmp;
+        nb[2] = nb_tmp;
+    };
+
+    transpose12(src0_bsnd_ne, src0_bsnd_nb);
+    transpose12(src1_bsnd_ne, src1_bsnd_nb);
+    transpose12(src2_bsnd_ne, src2_bsnd_nb);
+
     float maxBias = 0.0f;
     float scaleValue = 1.0f;
     float logitSoftcap = 0.0f;
@@ -3206,11 +3233,12 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
     void* src0_f16_buffer = nullptr;
 
     if (ggml_cann_type_mapping(src0->type) != faDataType){
-        aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0);
+        aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0, src0_bsnd_ne,
+            src0_bsnd_nb, GGML_MAX_DIMS);
         src0_f16_buffer = src0_f16_allocator.alloc(
             ggml_nelements(src0) * faElemSize);
 
-        int64_t * src0_f16_ne = src0->ne;
+        int64_t * src0_f16_ne = src0_bsnd_ne;
         size_t src0_f16_nb[GGML_MAX_DIMS];
         src0_f16_nb[0] = sizeof(uint16_t);
         for (int i = 1; i < GGML_MAX_DIMS; ++i){
@@ -3224,20 +3252,23 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
         aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, faDataType);
         ggml_cann_release_resources(ctx, acl_src0_f32_tensor);
     }else {
-        acl_src0_f16_tensor = ggml_cann_create_tensor(src0);
+        acl_src0_f16_tensor = ggml_cann_create_tensor(src0, src0_bsnd_ne,
+            src0_bsnd_nb, GGML_MAX_DIMS);
     }
 
     // Step 2: create the acl tensors for src1 (Key), src2 (Value),
     // and the direct output from FusedInferAttention
 
-    acl_src1_f16_tensor = ggml_cann_create_tensor(src1);
-    acl_src2_f16_tensor = ggml_cann_create_tensor(src2);
+    acl_src1_f16_tensor = ggml_cann_create_tensor(src1, src1_bsnd_ne,
+        src1_bsnd_nb, GGML_MAX_DIMS);
+    acl_src2_f16_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne,
+        src2_bsnd_nb, GGML_MAX_DIMS);
 
     ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
     void* out_f16_buffer = out_f16_allocator.alloc(
         ggml_nelements(dst) * faElemSize);
 
-    int64_t * out_f16_ne = src0->ne;
+    int64_t * out_f16_ne = src0_bsnd_ne;
     size_t out_f16_nb[GGML_MAX_DIMS];
     out_f16_nb[0] = faElemSize;
     for (int i = 1; i < GGML_MAX_DIMS; ++i){
@@ -3342,7 +3373,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
     // double scaleValue = 1 / sqrt(src0->ne[0]); // 1/sqrt(d)
     int64_t preTokens = 65535;
     int64_t nextTokens = 65535;
-    char layout[5] = {'B', 'N', 'S', 'D', 0};
+    char layout[5] = {'B', 'S', 'N', 'D', 0};
     int64_t sparseMode = 0;
     int64_t innerPrecise = (src0->ne[1] == 1) ? 0 : 2;
     int64_t blockSize = 0;
@@ -3379,32 +3410,9 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
     );
 
     // Step 6: post-processing, permute and cast to f32
-
-    int64_t new_dim[] = {0, 2, 1, 3};
     aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
-
-    if (ggml_cann_type_mapping(dst->type) != faDataType){
-        ggml_cann_pool_alloc perm_out_f16_allocator(ctx.pool());
-        perm_out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize);
-        void* perm_out_f16_buffer = perm_out_f16_allocator.get();
-
-        int64_t * perm_out_f16_ne = dst->ne;
-        size_t perm_out_f16_nb[GGML_MAX_DIMS];
-        perm_out_f16_nb[0] = faElemSize;
-        for (int i = 1; i < GGML_MAX_DIMS; ++i){
-            perm_out_f16_nb[i] = perm_out_f16_nb[i - 1] * perm_out_f16_ne[i - 1];
-        }
-        aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor(
-            perm_out_f16_buffer, faDataType, faElemSize,
-            perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS);
-        aclnn_permute(ctx, acl_dst_f16_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS);
-        aclnn_cast(ctx,
-            acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
-        ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor);
-    }else {
-        // only need to permute
-        aclnn_permute(ctx, acl_dst_f16_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS);
-    }
+    // TODO: when dst is fp16, don't need cast
+    aclnn_cast(ctx, acl_dst_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
     ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
         acl_src1_f16_tensor,
         acl_src2_f16_tensor,
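Note on what the new transpose12 helper buys: a ggml tensor is consumed purely through its size array ne and byte-stride array nb, so copying those arrays and swapping the entries of dims 1 and 2 gives a BSND-ordered view of the BNSD-strided Q/K/V tensors without moving any data. Passing those views together with the "BSND" layout string to the fused attention call is what lets the old aclnn_permute/cast post-processing collapse into a single cast. Below is a minimal, standalone C++ sketch of that stride-swap trick; View4D, its sizes, and main are invented for illustration and are not part of the ggml or CANN APIs.

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <utility>
#include <vector>

// A 4-D view described only by element counts (ne) and byte strides (nb):
// element (i0, i1, i2, i3) lives at data + i0*nb[0] + i1*nb[1] + i2*nb[2] + i3*nb[3].
struct View4D {
    int64_t ne[4];
    size_t  nb[4];
    const float * data;

    float at(int64_t i0, int64_t i1, int64_t i2, int64_t i3) const {
        const char * p = (const char *) data +
            i0 * nb[0] + i1 * nb[1] + i2 * nb[2] + i3 * nb[3];
        return *(const float *) p;
    }
};

// Same idea as the transpose12 lambda in the patch: exchanging the (ne, nb)
// pair of dims 1 and 2 produces a view with those two axes swapped,
// without copying or touching the underlying buffer.
static void transpose12(int64_t * ne, size_t * nb) {
    std::swap(ne[1], ne[2]);
    std::swap(nb[1], nb[2]);
}

int main() {
    // A contiguous buffer with sizes {2, 3, 4, 1}, dim 0 fastest
    // (strides built the same way as the src0_f16_nb loop in the diff).
    int64_t ne[4] = {2, 3, 4, 1};
    size_t  nb[4];
    nb[0] = sizeof(float);
    for (int i = 1; i < 4; ++i) {
        nb[i] = nb[i - 1] * ne[i - 1];
    }

    std::vector<float> buf(2 * 3 * 4);
    for (size_t i = 0; i < buf.size(); ++i) {
        buf[i] = (float) i;
    }

    View4D v;
    std::memcpy(v.ne, ne, sizeof(ne));
    std::memcpy(v.nb, nb, sizeof(nb));
    v.data = buf.data();

    View4D v_swapped = v;              // copies sizes/strides, shares the data pointer
    transpose12(v_swapped.ne, v_swapped.nb);

    // The swapped view indexed with swapped coordinates hits the same element,
    // so a kernel expecting the other layout can read the buffer in place.
    std::printf("%.1f == %.1f\n", v.at(1, 2, 3, 0), v_swapped.at(1, 3, 2, 0));
    return 0;
}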