From 7eb14d5a6bdc428d25fa848c785f0775dc408876 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Fri, 19 Apr 2024 13:50:17 -0400
Subject: [PATCH 1/2] Fix flash-attn for AMD

---
 ggml-cuda/common.cuh | 148 +++++++++++++++++++++----------------------
 ggml-cuda/fattn.cu   |  42 +++++++-----
 2 files changed, 99 insertions(+), 91 deletions(-)

diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh
index 156eba6d1ef74..e742cd3c2aca9 100644
--- a/ggml-cuda/common.cuh
+++ b/ggml-cuda/common.cuh
@@ -232,80 +232,6 @@ typedef float dfloat; // dequantize float
 typedef float2 dfloat2;
 #endif //GGML_CUDA_F16
 
-[[noreturn]]
-static __device__ void no_device_code(
-    const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
-
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-    printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
-           file_name, line, function_name, arch);
-    GGML_UNUSED(arch_list);
-#else
-    printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
-           file_name, line, function_name, arch, arch_list);
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-    __trap();
-
-    GGML_UNUSED(no_device_code); // suppress unused function warning
-}
-
-#ifdef __CUDA_ARCH__
-#define NO_DEVICE_CODE no_device_code(__FILE__, __LINE__, __FUNCTION__, __CUDA_ARCH__, STRINGIZE(__CUDA_ARCH_LIST__))
-#else
-#define NO_DEVICE_CODE //GGML_ASSERT(false && "NO_DEVICE_CODE not valid in host code.")
-#endif // __CUDA_ARCH__
-
-static __device__ __forceinline__ float warp_reduce_sum(float x) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
-    }
-    return x;
-}
-
-static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
-        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
-    }
-    return a;
-}
-
-static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
-    }
-    return a;
-#else
-    GGML_UNUSED(a);
-    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
-}
-
-static __device__ __forceinline__ float warp_reduce_max(float x) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
-    }
-    return x;
-}
-
-static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
-    }
-    return x;
-#else
-    GGML_UNUSED(x);
-    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
-}
-
 #if CUDART_VERSION < 12000
 static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) {
     const uint32_t mask_low  = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b)));
@@ -397,6 +323,80 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
 }
 #endif // defined(GGML_USE_HIPBLAS)
 
+#ifdef __CUDA_ARCH__
+#define NO_DEVICE_CODE no_device_code(__FILE__, __LINE__, __FUNCTION__, __CUDA_ARCH__, STRINGIZE(__CUDA_ARCH_LIST__))
+#else
+#define NO_DEVICE_CODE //GGML_ASSERT(false && "NO_DEVICE_CODE not valid in host code.")
+#endif // __CUDA_ARCH__
+
+[[noreturn]]
+static __device__ void no_device_code(
+    const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
+
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
+           file_name, line, function_name, arch);
+    GGML_UNUSED(arch_list);
+#else
+    printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
+           file_name, line, function_name, arch, arch_list);
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    __trap();
+
+    GGML_UNUSED(no_device_code); // suppress unused function warning
+}
+
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    }
+    return x;
+}
+
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+    }
+    return a;
+}
+
+static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
+    }
+    return a;
+#else
+    GGML_UNUSED(a);
+    NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+}
+
+static __device__ __forceinline__ float warp_reduce_max(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+}
+
+static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+#else
+    GGML_UNUSED(x);
+    NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+}
+
 #define FP16_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \
     defined(RDNA1) || defined(RDNA2) || defined(RDNA3) : __CUDA_ARCH__ >= CC_PASCAL
diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu
index df1e80068b334..1e5f4410bb0c6 100644
--- a/ggml-cuda/fattn.cu
+++ b/ggml-cuda/fattn.cu
@@ -4,8 +4,14 @@
 #include <cstdint>
 
 #if FP16_MMA_AVAILABLE
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#include <rocwmma/rocwmma.hpp>
+namespace wmma = rocwmma;
+#else
 #include <mma.h>
-#endif
+namespace wmma = nvcuda::wmma;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#endif // FP16_MMA_AVAILABLE
 
 #define FATTN_KQ_STRIDE 256
 #define HALF_MAX_HALF   __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
@@ -231,11 +237,11 @@ static __global__ void flash_attn_ext_f16(
     constexpr int frag_m = ncols == 8 ? 32 : 16;
     constexpr int frag_n = ncols == 8 ? 8 : 16;
     static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>                      frag_c_KQ;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half>                          frag_c_VKQ;
+    typedef wmma::fragment<wmma::matrix_a,    frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
+    typedef wmma::fragment<wmma::matrix_a,    frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
+    typedef wmma::fragment<wmma::matrix_b,    frag_m, frag_n, 16, half, wmma::col_major> frag_b;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>              frag_c_KQ;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half>                  frag_c_VKQ;
 
     constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
     constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
@@ -319,7 +325,7 @@ static __global__ void flash_attn_ext_f16(
     for (int i0 = 0; i0 < D; i0 += 16) {
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
+            wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
         }
     }
 
@@ -333,20 +339,20 @@ static __global__ void flash_attn_ext_f16(
         frag_c_KQ KQ_c[ncols/frag_n];
 #pragma unroll
         for (int j = 0; j < ncols/frag_n; ++j) {
-            nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
+            wmma::fill_fragment(KQ_c[j], 0.0f);
         }
 #pragma unroll
         for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
             frag_a_K K_a;
-            nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+            wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+                wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
             }
         }
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
+            wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, wmma::mem_col_major);
         }
     }
 
@@ -452,7 +458,7 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
         for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
             const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
-            nvcuda::wmma::load_matrix_sync(
+            wmma::load_matrix_sync(
                 KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
                 KQ + j0*(kqar*kqs_padded) + k,
                 kqar*kqs_padded);
@@ -464,7 +470,7 @@ static __global__ void flash_attn_ext_f16(
     for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
 #pragma unroll
         for (int j = 0; j < ncols/frag_n; ++j) {
-            nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
+            wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
         }
 
 #pragma unroll
@@ -472,10 +478,10 @@ static __global__ void flash_attn_ext_f16(
         for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
             const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
             frag_a_V v_a;
-            nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+            wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
+                wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
             }
         }
     }
@@ -487,10 +493,10 @@ static __global__ void flash_attn_ext_f16(
     for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::store_matrix_sync(
+            wmma::store_matrix_sync(
                 KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
                 VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
-                D_padded, nvcuda::wmma::mem_col_major);
+                D_padded, wmma::mem_col_major);
         }
     }
 
@@ -863,6 +869,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
         return;
     }
 
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) // 32x8 tensor cores are not available on AMD.
     if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
         constexpr int cols_per_block = 8;
         constexpr int nwarps         = 4;
@@ -885,6 +892,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
         }
         return;
     }
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 16;

From 3e560c8665d4ea627be920a26da6d83811fde3b4 Mon Sep 17 00:00:00 2001
From: Jerome
Date: Sat, 20 Apr 2024 11:06:03 -0400
Subject: [PATCH 2/2] Fix flashattn

---
 CMakeLists.txt       |  2 +-
 ggml-cuda/common.cuh |  9 ++-------
 ggml-cuda/fattn.cu   | 10 ++++++++--
 llama.cpp            |  7 -------
 4 files changed, 11 insertions(+), 17 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 477c5b57c20e7..d005244077663 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1175,7 +1175,7 @@ add_library(ggml OBJECT
             )
 
 target_include_directories(ggml PUBLIC . ${LLAMA_EXTRA_INCLUDES})
-target_compile_features   (ggml PUBLIC c_std_11) # don't bump
+target_compile_features   (ggml PUBLIC cxx_std_17) # don't bump
 
 target_link_libraries(ggml PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS})
diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh
index e742cd3c2aca9..0069b9e52f786 100644
--- a/ggml-cuda/common.cuh
+++ b/ggml-cuda/common.cuh
@@ -364,16 +364,11 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
 }
 
 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
         a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
     }
     return a;
-#else
-    GGML_UNUSED(a);
-    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 }
 
 static __device__ __forceinline__ float warp_reduce_max(float x) {
@@ -399,8 +394,8 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
 
 #define FP16_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \
     defined(RDNA1) || defined(RDNA2) || defined(RDNA3) : __CUDA_ARCH__ >= CC_PASCAL
-
-#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+#define FP16_MMA_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \
+    defined(RDNA3) : __CUDA_ARCH__ >= CC_VOLTA
 
 // TODO: move to ggml-common.h
 static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu
index 1e5f4410bb0c6..a605ed5665f00 100644
--- a/ggml-cuda/fattn.cu
+++ b/ggml-cuda/fattn.cu
@@ -7,6 +7,12 @@
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 #include <rocwmma/rocwmma.hpp>
 namespace wmma = rocwmma;
+inline __device__ __half2 __hmax2(__half2 x, __half2 y) {
+    return __half2_raw{
+        {{__hmax(__half2_raw(x).x, __half2_raw(y).x),
+          __hmax(__half2_raw(x).y, __half2_raw(y).y)}}
+    };
+}
 #else
 #include <mma.h>
 namespace wmma = nvcuda::wmma;
@@ -339,7 +345,7 @@ static __global__ void flash_attn_ext_f16(
         frag_c_KQ KQ_c[ncols/frag_n];
 #pragma unroll
         for (int j = 0; j < ncols/frag_n; ++j) {
-            wmma::fill_fragment(KQ_c[j], 0.0f);
+            wmma::fill_fragment(KQ_c[j], KQ_acc_t{0.0f});
         }
 #pragma unroll
         for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
@@ -470,7 +476,7 @@ static __global__ void flash_attn_ext_f16(
     for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
 #pragma unroll
         for (int j = 0; j < ncols/frag_n; ++j) {
-            wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
+            wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], __half{0.0f});
         }
 
 #pragma unroll
diff --git a/llama.cpp b/llama.cpp
index 18d6297ce1dfd..04c5d770676cd 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -15473,13 +15473,6 @@ struct llama_context * llama_new_context_with_model(
         cparams.flash_attn = false;
     }
 
-#ifdef GGML_USE_HIPBLAS
-    if (cparams.flash_attn) {
-        LLAMA_LOG_WARN("%s: flash_attn is not yet compatible with HIPBLAS builds - forcing off\n", __func__);
-        cparams.flash_attn = false;
-    }
-#endif
-
     if (params.seed == LLAMA_DEFAULT_SEED) {
         params.seed = time(NULL);
     }