Skip to content

Commit 78ee06e

Browse files
partially revert changes
1 parent: 57bde8c · commit: 78ee06e

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

ggml-cuda/common.cuh

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -429,7 +429,7 @@ static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const hal
429429
GGML_UNUSED(a);
430430
GGML_UNUSED(b);
431431
NO_DEVICE_CODE;
432-
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
432+
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
433433
}
434434

435435
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {

ggml-cuda/fattn.cu

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -61,6 +61,7 @@ static __global__ void flash_attn_vec_ext_f16(
6161
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
6262
constexpr int nwarps = D / WARP_SIZE;
6363
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
64+
__builtin_assume(tid < D);
6465

6566
__shared__ half KQ[ncols*D];
6667
#pragma unroll
@@ -106,7 +107,10 @@ static __global__ void flash_attn_vec_ext_f16(
106107
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
107108
// Calculate KQ tile and keep track of new maximum KQ values:
108109
half kqmax_new[ncols];
109-
memcpy(kqmax_new, kqmax, sizeof(kqmax));
110+
#pragma unroll
111+
for (int j = 0; j < ncols; ++j) {
112+
kqmax_new[j] = kqmax[j];
113+
}
110114

111115
#pragma unroll
112116
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
@@ -123,7 +127,7 @@ static __global__ void flash_attn_vec_ext_f16(
123127

124128
const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
125129
#pragma unroll
126-
for (int j = 0; j < ncols; ++j) {
130+
for (int j = 0; j < ncols; ++j) {
127131
sum2[j] += K_ik * Q_h2[j][k_KQ_0/WARP_SIZE];
128132
}
129133
}

0 commit comments

Comments
 (0)