Commit db86df3

[CANN]: Optimize FA layout from BNSD to BSND

Signed-off-by: noemotiovon <[email protected]>

1 parent 92e61dd commit db86df3

1 file changed: 43 additions, 35 deletions
ggml/src/ggml-cann/aclnn_ops.cpp

@@ -3180,11 +3180,38 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 
 void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
 
-    ggml_tensor* src0 = dst->src[0]; // q, fp32
-    ggml_tensor* src1 = dst->src[1]; // k, fp16
-    ggml_tensor* src2 = dst->src[2]; // v, fp16
+    ggml_tensor* src0 = dst->src[0]; // q, fp32 | B, N, S, D (uncont) -> B, S, N, D (cont)
+    ggml_tensor* src1 = dst->src[1]; // k, fp16 | B, N, S, D (uncont) -> B, S, N, D (cont)
+    ggml_tensor* src2 = dst->src[2]; // v, fp16 | B, N, S, D (uncont) -> B, S, N, D (cont)
     ggml_tensor* src3 = dst->src[3]; // mask, fp16
 
+    // B, N, S, D (uncont) -> B, S, N, D (cont)
+    int64_t src0_bsnd_ne[GGML_MAX_DIMS];
+    memcpy(src0_bsnd_ne, src0->ne, GGML_MAX_DIMS * sizeof(int64_t));
+    size_t src0_bsnd_nb[GGML_MAX_DIMS];
+    memcpy(src0_bsnd_nb, src0->nb, GGML_MAX_DIMS * sizeof(size_t));
+    int64_t src1_bsnd_ne[GGML_MAX_DIMS];
+    memcpy(src1_bsnd_ne, src1->ne, GGML_MAX_DIMS * sizeof(int64_t));
+    size_t src1_bsnd_nb[GGML_MAX_DIMS];
+    memcpy(src1_bsnd_nb, src1->nb, GGML_MAX_DIMS * sizeof(size_t));
+    int64_t src2_bsnd_ne[GGML_MAX_DIMS];
+    memcpy(src2_bsnd_ne, src2->ne, GGML_MAX_DIMS * sizeof(int64_t));
+    size_t src2_bsnd_nb[GGML_MAX_DIMS];
+    memcpy(src2_bsnd_nb, src2->nb, GGML_MAX_DIMS * sizeof(size_t));
+
+    auto transpose12 = [](int64_t* ne, size_t* nb) {
+        int64_t ne_tmp = ne[1];
+        size_t nb_tmp = nb[1];
+        ne[1] = ne[2];
+        nb[1] = nb[2];
+        ne[2] = ne_tmp;
+        nb[2] = nb_tmp;
+    };
+
+    transpose12(src0_bsnd_ne, src0_bsnd_nb);
+    transpose12(src1_bsnd_ne, src1_bsnd_nb);
+    transpose12(src2_bsnd_ne, src2_bsnd_nb);
+
     float maxBias = 0.0f;
     float scaleValue = 1.0f;
     float logitSoftcap = 0.0f;
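
What this hunk does: rather than materializing a permuted copy of Q, K, and V, the patch clones each tensor's extents (ne) and byte strides (nb) and swaps entries 1 and 2. Since ggml stores ne/nb innermost-first, a BNSD tensor carries ne = {D, S, N, B}, so the swap re-labels the same device buffer as BSND with zero data movement. A minimal standalone sketch of the trick (plain C++ on host memory, illustrative names, not ggml code):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Extents innermost-first, as ggml stores them. For a BNSD tensor:
        // ne = {D, S, N, B}; dense byte strides follow the same order.
        const int64_t D = 2, S = 3, N = 2, B = 1;
        int64_t ne[4] = {D, S, N, B};
        size_t nb[4];
        nb[0] = sizeof(float);
        for (int i = 1; i < 4; ++i) nb[i] = nb[i - 1] * (size_t) ne[i - 1];

        float data[D * S * N * B];
        for (int64_t i = 0; i < D * S * N * B; ++i) data[i] = (float) i;

        // The same swap as the patch's transpose12 lambda: exchange extent
        // and stride of dims 1 and 2. The buffer itself is never touched.
        int64_t ne_tmp = ne[1]; ne[1] = ne[2]; ne[2] = ne_tmp;
        size_t nb_tmp = nb[1]; nb[1] = nb[2]; nb[2] = nb_tmp;

        // After the swap the view is BSND: ne = {D, N, S, B}. Walking it
        // with the swapped strides visits the original elements in
        // transposed order.
        const char* base = (const char*) data;
        for (int64_t s = 0; s < ne[2]; ++s)
            for (int64_t n = 0; n < ne[1]; ++n)
                for (int64_t d = 0; d < ne[0]; ++d)
                    printf("bsnd(s=%lld, n=%lld, d=%lld) = %g\n",
                           (long long) s, (long long) n, (long long) d,
                           *(const float*) (base + d * nb[0] + n * nb[1] + s * nb[2]));
        return 0;
    }

Reading element (s, n, d) through the swapped view yields what the original view called (n, s, d), which is exactly the transpose the old code obtained by launching aclnn_permute on the output.
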
@@ -3206,11 +3233,12 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
     void* src0_f16_buffer = nullptr;
 
     if(ggml_cann_type_mapping(src0->type) != faDataType){
-        aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0);
+        aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0, src0_bsnd_ne,
+                                                                 src0_bsnd_nb, GGML_MAX_DIMS);
         src0_f16_buffer = src0_f16_allocator.alloc(
             ggml_nelements(src0) * faElemSize);
 
-        int64_t* src0_f16_ne = src0->ne;
+        int64_t* src0_f16_ne = src0_bsnd_ne;
         size_t src0_f16_nb[GGML_MAX_DIMS];
         src0_f16_nb[0] = sizeof(uint16_t);
         for(int i = 1; i < GGML_MAX_DIMS; ++i){
@@ -3224,20 +3252,23 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
         aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, faDataType);
         ggml_cann_release_resources(ctx, acl_src0_f32_tensor);
     }else{
-        acl_src0_f16_tensor = ggml_cann_create_tensor(src0);
+        acl_src0_f16_tensor = ggml_cann_create_tensor(src0, src0_bsnd_ne,
+                                                      src0_bsnd_nb, GGML_MAX_DIMS);
     }
 
     // Step 2: create the acl tensors for src1 (Key), src2 (Value),
     // and the direct output from FusedInferAttention
 
-    acl_src1_f16_tensor = ggml_cann_create_tensor(src1);
-    acl_src2_f16_tensor = ggml_cann_create_tensor(src2);
+    acl_src1_f16_tensor = ggml_cann_create_tensor(src1, src1_bsnd_ne,
+                                                  src1_bsnd_nb, GGML_MAX_DIMS);
+    acl_src2_f16_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne,
+                                                  src2_bsnd_nb, GGML_MAX_DIMS);
 
     ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
     void* out_f16_buffer = out_f16_allocator.alloc(
         ggml_nelements(dst) * faElemSize);
 
-    int64_t* out_f16_ne = src0->ne;
+    int64_t* out_f16_ne = src0_bsnd_ne;
     size_t out_f16_nb[GGML_MAX_DIMS];
     out_f16_nb[0] = faElemSize;
     for(int i = 1; i < GGML_MAX_DIMS; ++i){
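
Note on the two hunks above: the ggml_cann_create_tensor overload taking explicit ne/nb (as used in this patch) builds the aclTensor from the caller-supplied shape and strides instead of the tensor's own, so Q, K, and V are handed to CANN as BSND views of unmoved memory. The freshly allocated f16 staging copy of Q and the raw FA output, by contrast, get dense strides over the swapped extents (src0_f16_ne and out_f16_ne now point at src0_bsnd_ne), making both buffers contiguous in BSND. A sketch of that stride recurrence, with an illustrative helper name that is not part of ggml:

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    // Illustrative helper (not ggml API): dense byte strides for extents
    // given innermost-first. This is the same recurrence the patch applies
    // to src0_f16_nb and out_f16_nb: nb[0] = element size,
    // nb[i] = nb[i-1] * ne[i-1].
    static void dense_strides(const int64_t* ne, size_t elem_size,
                              size_t* nb, int dims) {
        nb[0] = elem_size;
        for (int i = 1; i < dims; ++i)
            nb[i] = nb[i - 1] * (size_t) ne[i - 1];
    }

    int main() {
        // Hypothetical BSND extents, innermost-first: {D, N, S, B}.
        const int64_t ne[4] = {64, 8, 32, 2};
        size_t nb[4];
        dense_strides(ne, sizeof(uint16_t), nb, 4); // fp16 staging buffer
        for (int i = 0; i < 4; ++i)
            printf("nb[%d] = %zu bytes\n", i, nb[i]);
        return 0;
    }
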
@@ -3342,7 +3373,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
     // double scaleValue = 1 / sqrt(src0->ne[0]); // 1/sqrt(d)
     int64_t preTokens = 65535;
    int64_t nextTokens = 65535;
-    char layout[5] = {'B', 'N', 'S', 'D', 0};
+    char layout[5] = {'B', 'S', 'N', 'D', 0};
     int64_t sparseMode = 0;
     int64_t innerPrecise = (src0->ne[1] == 1) ? 0 : 2;
     int64_t blockSize = 0;
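
With the inputs now described as BSND, the layout string passed to the fused-attention call flips to match; it is the one place the kernel is told how to interpret the four dimensions. The exact semantics of the string belong to the CANN FusedInferAttentionScore documentation, but under the bookkeeping assumed in this patch (ggml ne innermost-first, layout string outermost-first) the correspondence is as this small illustration prints:

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Hypothetical extents after the dim-1/2 swap, innermost-first:
        // {D, N, S, B}.
        const int64_t ne[4] = {64, 8, 32, 2};
        const char layout[5] = {'B', 'S', 'N', 'D', 0};
        // layout[i] names ne[3 - i]: B = ne[3], S = ne[2], N = ne[1],
        // D = ne[0].
        for (int i = 0; i < 4; ++i)
            printf("%c = %lld\n", layout[i], (long long) ne[3 - i]);
        return 0;
    }
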
@@ -3379,32 +3410,9 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
     );
 
     // Step 6: post-processing, permute and cast to f32
-
-    int64_t new_dim[] = {0, 2, 1, 3};
     aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
-
-    if(ggml_cann_type_mapping(dst->type) != faDataType){
-        ggml_cann_pool_alloc perm_out_f16_allocator(ctx.pool());
-        perm_out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize);
-        void* perm_out_f16_buffer = perm_out_f16_allocator.get();
-
-        int64_t* perm_out_f16_ne = dst->ne;
-        size_t perm_out_f16_nb[GGML_MAX_DIMS];
-        perm_out_f16_nb[0] = faElemSize;
-        for(int i = 1; i < GGML_MAX_DIMS; ++i){
-            perm_out_f16_nb[i] = perm_out_f16_nb[i - 1] * perm_out_f16_ne[i - 1];
-        }
-        aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor(
-            perm_out_f16_buffer, faDataType, faElemSize,
-            perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS);
-        aclnn_permute(ctx, acl_dst_f16_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS);
-        aclnn_cast(ctx,
-            acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
-        ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor);
-    }else{
-        // only need to permute
-        aclnn_permute(ctx, acl_dst_f16_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS);
-    }
+    // TODO: when dst is fp16, don't need cast
+    aclnn_cast(ctx, acl_dst_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
     ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
                                 acl_src1_f16_tensor,
                                 acl_src2_f16_tensor,
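
The payoff is this last hunk: because the kernel now writes BSND directly, its raw output already matches dst's memory order, so the whole permute branch (and its extra scratch buffer) disappears and only the dtype cast remains; the TODO notes that even the cast is skippable once dst is already fp16. As a quick sanity check, the removed permutation new_dim = {0, 2, 1, 3}, read as "output axis i takes input axis new_dim[i]" (the assumption this sketch makes about aclnn_permute), maps a BNSD result to exactly the BSND order the kernel now emits:

    #include <cstdio>

    int main() {
        // Old FA output axes, innermost-first for layout BNSD: {D, S, N, B}.
        const char* bnsd[4] = {"D", "S", "N", "B"};
        const int new_dim[4] = {0, 2, 1, 3}; // the permutation the patch removes
        printf("permuted axes, innermost-first:");
        for (int i = 0; i < 4; ++i)
            printf(" %s", bnsd[new_dim[i]]);
        printf("\n"); // D N S B, i.e. BSND read outermost-first
        return 0;
    }
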
