Skip to content

Commit 29e01c3

Browse files
CUDA: add FP32 FlashAttention vector kernel
1 parent d46dbc7 commit 29e01c3

10 files changed

+873
-430
lines changed

ggml-cuda.cu

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2979,6 +2979,16 @@ GGML_CALL void ggml_backend_cuda_get_device_memory(int device, size_t * free, si
29792979
CUDA_CHECK(cudaMemGetInfo(free, total));
29802980
}
29812981

2982+
GGML_CALL int ggml_backend_cuda_get_device_cc(int device) {
2983+
cudaDeviceProp prop;
2984+
CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
2985+
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
2986+
return 100*prop.major + 10*prop.minor + CC_OFFSET_AMD;
2987+
#else
2988+
return 100*prop.major + 10*prop.minor;
2989+
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
2990+
}
2991+
29822992
GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
29832993
if (getenv("GGML_CUDA_REGISTER_HOST") == nullptr) {
29842994
return false;

ggml-cuda.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type
3434
GGML_API GGML_CALL int ggml_backend_cuda_get_device_count(void);
3535
GGML_API GGML_CALL void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size);
3636
GGML_API GGML_CALL void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total);
37+
GGML_API GGML_CALL int ggml_backend_cuda_get_device_cc(int device);
3738

3839
GGML_API GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size);
3940
GGML_API GGML_CALL void ggml_backend_cuda_unregister_host_buffer(void * buffer);

ggml-cuda/common.cuh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,10 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
321321

322322
#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
323323

324+
static bool fast_fp16_available(const int cc) {
325+
return cc >= CC_PASCAL && cc != 610;
326+
}
327+
324328
static bool fp16_mma_available(const int cc) {
325329
return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
326330
}

ggml-cuda/fattn-common.cuh

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#define FATTN_KQ_STRIDE 256
2+
#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
3+
#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
4+
5+
template<int D, int parallel_blocks> // D == head size
6+
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
7+
__launch_bounds__(D, 1)
8+
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
9+
static __global__ void flash_attn_combine_results(
10+
const float * __restrict__ VKQ_parts,
11+
const float2 * __restrict__ VKQ_meta,
12+
float * __restrict__ dst) {
13+
VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
14+
VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x;
15+
dst += D * gridDim.y*blockIdx.x;
16+
17+
const int tid = threadIdx.x;
18+
__builtin_assume(tid < D);
19+
20+
__shared__ float2 meta[parallel_blocks];
21+
if (tid < 2*parallel_blocks) {
22+
((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid];
23+
}
24+
25+
__syncthreads();
26+
27+
float kqmax = meta[0].x;
28+
#pragma unroll
29+
for (int l = 1; l < parallel_blocks; ++l) {
30+
kqmax = max(kqmax, meta[l].x);
31+
}
32+
33+
float VKQ_numerator = 0.0f;
34+
float VKQ_denominator = 0.0f;
35+
#pragma unroll
36+
for (int l = 0; l < parallel_blocks; ++l) {
37+
const float diff = meta[l].x - kqmax;
38+
const float KQ_max_scale = expf(diff);
39+
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
40+
*((uint32_t *) &KQ_max_scale) &= ftz_mask;
41+
42+
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
43+
VKQ_denominator += KQ_max_scale * meta[l].y;
44+
}
45+
46+
dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
47+
}

0 commit comments

Comments
 (0)