Commit fb085fa

cuda : fix to F16 scalars + tune warps for RTX 2060
1 parent 2c04bee commit fb085fa

2 files changed: +61, -47 lines

ggml-cuda.cu

Lines changed: 49 additions & 45 deletions
@@ -6491,8 +6491,8 @@ static __global__ void flash_attn_ext_f16(
     __syncthreads();
 
     {
-        float S[Q];
-        float M[Q];
+        half S[Q];
+        half M[Q];
 
         for(int i = 0; i < Q; i++) {
             S[i] = 0.0f;
@@ -6579,67 +6579,68 @@ static __global__ void flash_attn_ext_f16(
             }
 
             // used to detect blocks full of -INF
-            float smax = -INFINITY;
+            half smax = -INFINITY;
 
             // online softmax
             if (C == 32) {
                 for (int64_t j = 0; j < Q; ++j) {
                     const int64_t p = lane_id;
 
-                    const float m = M[j];
-                    const float s = __half2float(ss[j*T + p]);
+                    const half m = M[j];
+                    const half s = ss[j*T + p];
 
-                    smax = warp_reduce_max(max(smax, s));
-                    M[j] = warp_reduce_max(max(M[j], s));
+                    smax = warp_reduce_max(__hmax(smax, s));
+                    M[j] = warp_reduce_max(__hmax(M[j], s));
 
-                    const float ms = m == -INFINITY ? 0.0f : expf(m - M[j]);
-                    const float vs = s == -INFINITY ? 0.0f : expf(s - M[j]);
+                    const half ms = __hisinf(m) ? 0.0f : expf(m - M[j]);
+                    const half vs = __hisinf(s) ? 0.0f : expf(s - M[j]);
 
                     S[j] = S[j]*ms + warp_reduce_sum(vs);
 
                     // create a QxQ diagonal matrix for rescaling the output
                     if (p == j) {
-                        ss[j*T + C + j] = __float2half(ms);
+                        ss[j*T + C + j] = ms;
                     }
 
                     // the P matrix from the paper (Q rows, C columns)
-                    ss[j*T + p] = __float2half(vs);
+                    ss[j*T + p] = vs;
                 }
             } else {
                 for (int64_t j = 0; j < Q; ++j) {
-                    const float m = M[j];
+                    const half m = M[j];
 
                     for (int64_t p = lane_id; p < C; p += NW) {
-                        const float s = __half2float(ss[j*T + p]);
+                        const half s = ss[j*T + p];
 
-                        smax = warp_reduce_max(max(smax, s));
-                        M[j] = warp_reduce_max(max(M[j], s));
+                        smax = warp_reduce_max(__hmax(smax, s));
+                        M[j] = warp_reduce_max(__hmax(M[j], s));
                     }
 
-                    const float ms = m == -INFINITY ? 0.0f : expf(m - M[j]);
+                    const half ms = __hisinf(m) ? 0.0f : expf(m - M[j]);
 
                     S[j] = S[j]*ms;
 
                     // create a QxQ diagonal matrix for rescaling the output
                     if (lane_id == j) {
-                        ss[j*T + C + j] = __float2half(ms);
+                        ss[j*T + C + j] = ms;
                     }
 
                     for (int64_t p = lane_id; p < C; p += NW) {
-                        const float s = __half2float(ss[j*T + p]);
+                        const half s = ss[j*T + p];
 
-                        const float vs = s == -INFINITY ? 0.0f : expf(s - M[j]);
+                        const half vs = __hisinf(s) ? 0.0f : expf(s - M[j]);
 
                         S[j] = S[j] + warp_reduce_sum(vs);
 
                         // the P matrix from the paper (Q rows, C columns)
-                        ss[j*T + p] = __float2half(vs);
+                        ss[j*T + p] = vs;
                     }
                 }
             }
 
+
             // skip -INF blocks
-            if (smax == -INFINITY) {
+            if (__hisinf(smax)) {
                 continue;
             }
 
@@ -6686,16 +6687,16 @@ static __global__ void flash_attn_ext_f16(
         // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
         for (int64_t j = 0; j < Q; ++j) {
             if (lane_id == 0) {
-                ss[j*T + 0] = __float2half(S[j]);
-                ss[j*T + 1] = __float2half(M[j]);
+                ss[j*T + 0] = S[j];
+                ss[j*T + 1] = M[j];
             }
         }
     }
 
     // reduce the warps sequentially
     for (int64_t sg = 1; sg < num_warps; ++sg) {
-        float S = 0.0f;
-        float M = -INFINITY;
+        half S = 0.0f;
+        half M = -INFINITY;
 
         __syncthreads();
 
@@ -6713,25 +6714,25 @@ static __global__ void flash_attn_ext_f16(
         // the first simdgroup accumulates the results from the other simdgroups
         if (warp_id == 0) {
             for (int64_t j = 0; j < Q; ++j) {
-                const float S0 = __half2float(ss[j*T + 0]);
-                const float S1 = __half2float(ss[j*T + sg*SH + 0]);
+                const half S0 = ss[j*T + 0];
+                const half S1 = ss[j*T + sg*SH + 0];
 
-                const float M0 = __half2float(ss[j*T + 1]);
-                const float M1 = __half2float(ss[j*T + sg*SH + 1]);
+                const half M0 = ss[j*T + 1];
+                const half M1 = ss[j*T + sg*SH + 1];
 
-                M = max(M0, M1);
+                M = __hmax(M0, M1);
 
-                const float ms0 = M0 == -INFINITY ? 0.0f : expf(M0 - M);
-                const float ms1 = M1 == -INFINITY ? 0.0f : expf(M1 - M);
+                const half ms0 = __hisinf(M0) ? 0.0f : expf(M0 - M);
+                const half ms1 = __hisinf(M1) ? 0.0f : expf(M1 - M);
 
                 S = S0*ms0 + S1*ms1;
 
                 if (lane_id == 0) {
-                    ss[j*T + 0] = __float2half(S);
-                    ss[j*T + 1] = __float2half(M);
+                    ss[j*T + 0] = S;
+                    ss[j*T + 1] = M;
 
-                    ss[j*T + C + j ] = __float2half(ms0);
-                    ss[j*T + C + j + sg*SH] = __float2half(ms1);
+                    ss[j*T + C + j ] = ms0;
+                    ss[j*T + C + j + sg*SH] = ms1;
                 }
             }
 
@@ -6774,10 +6775,10 @@ static __global__ void flash_attn_ext_f16(
     // final rescale with 1/S and store to global memory
     if (warp_id == 0) {
         for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
-            const float S = __half2float(ss[j*T + 0]);
+            const half S = ss[j*T + 0];
 
             for (int64_t i = lane_id; i < D; i += NW) {
-                dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i]) / S;
+                dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S);
            }
        }
    }
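Aside (not part of the commit): the kernel now feeds half values straight into warp_reduce_max() and warp_reduce_sum(), whose definitions are outside this diff. A minimal sketch of what half overloads could look like, assuming cuda_fp16.h, a fully active 32-lane warp, and a toolkit recent enough to provide __hmax for half (which the diff above already relies on):

// Hypothetical sketch, not code from this commit: half-precision overloads of the
// warp reduction helpers used above. Each __shfl_xor_sync butterfly step halves
// the number of distinct values until every lane holds the warp-wide result.
#include <cuda_fp16.h>

#define WARP_SIZE 32

static __device__ __forceinline__ half warp_reduce_max(half x) {
#pragma unroll
    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
        x = __hmax(x, __shfl_xor_sync(0xffffffff, x, offset, WARP_SIZE));
    }
    return x;
}

static __device__ __forceinline__ half warp_reduce_sum(half x) {
#pragma unroll
    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
        x = __hadd(x, __shfl_xor_sync(0xffffffff, x, offset, WARP_SIZE));
    }
    return x;
}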
@@ -10930,12 +10931,15 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     float scale;
     memcpy(&scale, KQV->op_params, sizeof(float));
 
-    const int nqpb = 16; // queries per block
-    const int ncpw = 32; // cache values per warp (does not work for other values)
+#define NQPB 16
+#define NCPW 32
+
+    const int nqpb = NQPB; // queries per block
+    const int ncpw = NCPW; // cache values per warp (does not work for other values)
 
     const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much?
                               // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why
-    const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, nwarps_max)) : 4;
+    const int nwarps = Q->ne[1] <= nqpb ? MAX(2, MIN(K->ne[1]/ncpw, nwarps_max)) : 2;
 
     dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
     dim3 block_dim(32, nwarps, 1);
@@ -10945,7 +10949,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     switch (Q->ne[0])
     {
         case 16:
-            flash_attn_ext_f16<16, 16, 32>
+            flash_attn_ext_f16<16, NQPB, NCPW>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                     (const char *) src0_extra->data_device[g_main_device], // Query
                     (const char *) src1_extra->data_device[g_main_device], // Key
@@ -10962,7 +10966,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
             );
             break;
         case 64:
-            flash_attn_ext_f16<64, 16, 32>
+            flash_attn_ext_f16<64, NQPB, NCPW>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                     (const char *) src0_extra->data_device[g_main_device], // Query
                     (const char *) src1_extra->data_device[g_main_device], // Key
@@ -10979,7 +10983,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
             );
             break;
         case 80:
-            flash_attn_ext_f16<80, 16, 32>
+            flash_attn_ext_f16<80, NQPB, NCPW>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                     (const char *) src0_extra->data_device[g_main_device], // Query
                     (const char *) src1_extra->data_device[g_main_device], // Key
@@ -10996,7 +11000,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
             );
             break;
         case 128:
-            flash_attn_ext_f16<128, 16, 32>
+            flash_attn_ext_f16<128, NQPB, NCPW>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                     (const char *) src0_extra->data_device[g_main_device], // Query
                     (const char *) src1_extra->data_device[g_main_device], // Key
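Aside (not part of the commit): a worked example of the retuned launch heuristic above, comparing a single-token decode against a long KV cache with a 512-token prompt. The numbers follow directly from NQPB, NCPW and the new MAX(2, ...) clamp:

// Hypothetical illustration, not code from this commit: what the heuristic above
// produces for two common shapes.
#include <algorithm>
#include <cstdio>

int main() {
    const int nqpb       = 16; // NQPB: queries per block
    const int ncpw       = 32; // NCPW: cache values per warp
    const int nwarps_max = 8;

    // (n_q, n_kv): single-token decode with a 4096-entry KV cache, then a 512-token prompt
    const int shapes[2][2] = { { 1, 4096 }, { 512, 512 } };

    for (const auto & s : shapes) {
        const int n_q = s[0], n_kv = s[1];
        const int nwarps   = n_q <= nqpb ? std::max(2, std::min(n_kv/ncpw, nwarps_max)) : 2;
        const int blocks_x = (n_q + nqpb - 1)/nqpb;
        // prints: n_q=1   n_kv=4096 -> nwarps=8 (256 threads/block), grid.x=1
        //         n_q=512 n_kv=512  -> nwarps=2 ( 64 threads/block), grid.x=32
        printf("n_q=%d n_kv=%d -> nwarps=%d (%d threads/block), grid.x=%d\n",
               n_q, n_kv, nwarps, 32*nwarps, blocks_x);
    }
    return 0;
}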

tests/test-backend-ops.cpp

Lines changed: 12 additions & 2 deletions
@@ -572,9 +572,19 @@ struct test_case {
         // duplicate the op
         size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU
         int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1;
+#if 1
         for (int i = 1; i < n_runs; i++) {
             gf->nodes[gf->n_nodes++] = out;
         }
+#else
+        n_runs = 1000;
+        int n_nodes = gf->n_nodes;
+        for (int i = 1; i < n_runs; i++) {
+            for (int j = 0; j < n_nodes; j++) {
+                gf->nodes[gf->n_nodes++] = gf->nodes[j];
+            }
+        }
+#endif
 
         // calculate memory
         size_t mem = n_runs * op_size(out);
@@ -2199,8 +2209,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_pad());
     test_cases.emplace_back(new test_leaky_relu());
 
-#if 0
-    for (int hs : { 64, 80, 96, 112, 128, 256, }) {
+#if 1
+    for (int hs : { 64, 80, 128, }) {
         for (int nh : { 32, }) {
             for (int kv : { 512, 1024, 2048, 4096, }) {
                 for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) {
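Aside (not part of the commit): the re-enabled loop sweeps 3 head sizes x 1 head count x 4 KV lengths x 7 batch sizes, i.e. 84 FLASH_ATTN_EXT cases per backend; a quick tally for reference:

// Hypothetical tally, not code from this commit: number of flash-attention test
// configurations generated by the loops above.
#include <cstdio>
#include <initializer_list>

int main() {
    int n_cases = 0;
    for (int hs : { 64, 80, 128 })
        for (int nh : { 32 })
            for (int kv : { 512, 1024, 2048, 4096 })
                for (int nb : { 1, 2, 4, 8, 512, 1024, 2048 }) {
                    (void) hs; (void) nh; (void) kv; (void) nb; // silence unused warnings
                    n_cases++;
                }
    printf("FLASH_ATTN_EXT test cases per backend: %d\n", n_cases); // prints 84
    return 0;
}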
